diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index e36eed7bd..b0cfa1c3d 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -10,7 +10,7 @@ jobs: - name: install Go uses: actions/setup-go@v2 with: - go-version: 1.18.x + go-version: 1.19.x - name: checkout code uses: actions/checkout@v2 with: @@ -43,7 +43,7 @@ jobs: test: strategy: matrix: - go-version: [1.17.x, 1.18.x] + go-version: [1.18.x, 1.19.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} needs: @@ -68,7 +68,9 @@ jobs: - name: install deps run: go install golang.org/x/tools/cmd/goimports@latest && go install github.com/klauspost/asmfmt/cmd/asmfmt@latest - name: Test - run: go test -p=1 -v -timeout=30m ./... + run: | + go test -p=1 -v -timeout=30m ./... + go test -p=1 -tags=purego -v -timeout=30m ./... - name: Test (32 bits & race) if: (matrix.os == 'ubuntu-latest') && (matrix.go-version == '1.18.x') run: | diff --git a/.golangci.yml b/.golangci.yml index e4dd73da6..091d52e10 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -11,6 +11,7 @@ linters: - gosimple - govet - ineffassign + # - errcheck run: issues-exit-code: 1 \ No newline at end of file diff --git a/README.md b/README.md index c8c368a78..242a7caa1 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ * [`bls12-377`] / [`bw6-761`] * [`bls24-315`] / [`bw6-633`] * [`bls12-378`] / [`bw6-756`] - * Each of these curve has a [`twistededwards`] sub-package with its companion curve which allow efficient elliptic curve cryptography inside zkSNARK circuits. + * Each of these curves has a [`twistededwards`] sub-package with its companion curve which allow efficient elliptic curve cryptography inside zkSNARK circuits. * [`field/goff`] - Finite field arithmetic code generator (blazingly fast big.Int) * [`fft`] - Fast Fourier Transform * [`fri`] - FRI (multiplicative) commitment scheme @@ -46,7 +46,7 @@ go get github.com/consensys/gnark-crypto ``` -Note if that if you use go modules, in `go.mod` the module path is case sensitive (use `consensys` and not `ConsenSys`). +Note that if you use go modules, in `go.mod` the module path is case sensitive (use `consensys` and not `ConsenSys`). ### Development @@ -54,7 +54,7 @@ Most (but not all) of the code is generated from the templates in `internal/gene The generated code contains little to no interfaces and is strongly typed with a field (generated by the `gnark-crypto/field` package). The two main factors driving this design choice are: -1. Performance: `gnark-crypto` algorithms manipulates millions (if not billions) of field elements. Interface indirection at this level, plus garbage collection indexing takes a heavy toll on perf. +1. Performance: `gnark-crypto` algorithms manipulate millions (if not billions) of field elements. Interface indirection at this level, plus garbage collection indexing takes a heavy toll on perf. 2. Need to derive (mostly) identical code for various moduli and curves, with consistent APIs. Generics introduce significant performance overhead and are not yet suited for high performance computing. To regenerate the files, see `internal/generator/main.go`. Run: @@ -117,4 +117,4 @@ This project is licensed under the Apache 2 License - see the [LICENSE](LICENSE) [`kzg`]: https://pkg.go.dev/github.com/consensys/gnark-crypto/ecc/bn254/fr/kzg [`plookup`]: https://pkg.go.dev/github.com/consensys/gnark-crypto/ecc/bn254/fr/plookup [`permutation`]: https://pkg.go.dev/github.com/consensys/gnark-crypto/ecc/bn254/fr/permutation -[`fiatshamir`]: https://pkg.go.dev/github.com/consensys/gnark-crypto/fiat-shamir \ No newline at end of file +[`fiatshamir`]: https://pkg.go.dev/github.com/consensys/gnark-crypto/fiat-shamir diff --git a/accumulator/merkletree/tree.go b/accumulator/merkletree/tree.go index 86e693a69..bb5ec8c72 100644 --- a/accumulator/merkletree/tree.go +++ b/accumulator/merkletree/tree.go @@ -84,7 +84,8 @@ func sum(h hash.Hash, data ...[]byte) []byte { // leafSum returns the hash created from data inserted to form a leaf. Leaf // sums are calculated using: -// Hash(0x00 || data) +// +// Hash(0x00 || data) func leafSum(h hash.Hash, data []byte) []byte { //return sum(h, leafHashPrefix, data) @@ -93,7 +94,8 @@ func leafSum(h hash.Hash, data []byte) []byte { // nodeSum returns the hash created from two sibling nodes being combined into // a parent node. Node sums are calculated using: -// Hash(0x01 || left sibling sum || right sibling sum) +// +// Hash(0x01 || left sibling sum || right sibling sum) func nodeSum(h hash.Hash, a, b []byte) []byte { //return sum(h, nodeHashPrefix, a, b) return sum(h, a, b) diff --git a/ecc/bls12-377/bls12-377.go b/ecc/bls12-377/bls12-377.go index f42bca432..8abee780d 100644 --- a/ecc/bls12-377/bls12-377.go +++ b/ecc/bls12-377/bls12-377.go @@ -1,23 +1,29 @@ // Package bls12377 efficient elliptic curve, pairing and hash to curve implementation for bls12-377. // // bls12-377: A Barreto--Lynn--Scott curve with -// embedding degree k=12 -// seed x₀=9586122913090633729 -// 𝔽r: r=8444461749428370424248824938781546531375899335154063827935233455917409239041 (x₀⁴-x₀²+1) -// 𝔽p: p=258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 ((x₀-1)² ⋅ r(x₀)/3+x₀) -// (E/𝔽p): Y²=X³+1 -// (Eₜ/𝔽p²): Y² = X³+1/u (D-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// +// embedding degree k=12 +// seed x₀=9586122913090633729 +// 𝔽r: r=8444461749428370424248824938781546531375899335154063827935233455917409239041 (x₀⁴-x₀²+1) +// 𝔽p: p=258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 ((x₀-1)² ⋅ r(x₀)/3+x₀) +// (E/𝔽p): Y²=X³+1 +// (Eₜ/𝔽p²): Y² = X³+1/u (D-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²+5 -// 𝔽p⁶[v] = 𝔽p²/v³-u -// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// 𝔽p²[u] = 𝔽p/u²+5 +// 𝔽p⁶[v] = 𝔽p²/v³-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// // optimal Ate loop size: -// x₀ +// +// x₀ +// // Security: estimated 126-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 253 bits and p¹² is 4521 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bls12377 diff --git a/ecc/bls12-377/fp/doc.go b/ecc/bls12-377/fp/doc.go index 01751eabd..c4d87f7b0 100644 --- a/ecc/bls12-377/fp/doc.go +++ b/ecc/bls12-377/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [6]uint64 // -// Usage +// type Element [6]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 9873a533a..b8b7cb80b 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 6 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [6]uint64 const ( - Limbs = 6 // number of 64 bits words needed to represent a Element - Bits = 377 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 6 // number of 64 bits words needed to represent a Element + Bits = 377 // number of bits needed to represent a Element + Bytes = 48 // number of bytes needed to represent a Element ) // Field modulus q @@ -72,8 +72,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -82,12 +82,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 9586122913090633727 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001", 16) } @@ -95,8 +89,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -107,7 +102,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -139,14 +134,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -260,15 +256,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -280,15 +274,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[5] > _x[5] { return 1 } else if _z[5] < _x[5] { @@ -329,8 +320,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 4793061456545316865, 0) @@ -429,67 +419,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -505,7 +437,7 @@ func (z *Element) Add(x, y *Element) *Element { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -529,7 +461,7 @@ func (z *Element) Double(x *Element) *Element { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -592,115 +524,219 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [6]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - z[5], z[4] = madd3(m, q5, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [7]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + + if t[6] != 0 { + // we need to reduce, we have a result on 7 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], _ = bits.Sub64(t[5], q5, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -710,7 +746,6 @@ func _mulGeneric(z, x, y *Element) { z[4], b = bits.Sub64(z[4], q4, b) z[5], _ = bits.Sub64(z[5], q5, b) } - } func _fromMontGeneric(z *Element) { @@ -784,7 +819,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -798,7 +833,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -870,6 +905,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -884,8 +948,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -913,23 +977,31 @@ var rSquare = Element{ 30958721782860680, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -948,51 +1020,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[40:48], z[0]) - binary.BigEndian.PutUint64(b[32:40], z[1]) - binary.BigEndian.PutUint64(b[24:32], z[2]) - binary.BigEndian.PutUint64(b[16:24], z[3]) - binary.BigEndian.PutUint64(b[8:16], z[4]) - binary.BigEndian.PutUint64(b[0:8], z[5]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[40:48], _z[0]) - binary.BigEndian.PutUint64(res[32:40], _z[1]) - binary.BigEndian.PutUint64(res[24:32], _z[2]) - binary.BigEndian.PutUint64(res[16:24], _z[3]) - binary.BigEndian.PutUint64(res[8:16], _z[4]) - binary.BigEndian.PutUint64(res[0:8], _z[5]) +// Bits provides access to z by returning its value as a little-endian [6]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [6]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1005,19 +1075,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 48-byte integer. +// If e is not a 48-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1035,17 +1130,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1067,20 +1161,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1089,7 +1183,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1098,7 +1192,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1138,7 +1232,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1147,10 +1241,87 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 48-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[40:48]) + z[1] = binary.BigEndian.Uint64((*b)[32:40]) + z[2] = binary.BigEndian.Uint64((*b)[24:32]) + z[3] = binary.BigEndian.Uint64((*b)[16:24]) + z[4] = binary.BigEndian.Uint64((*b)[8:16]) + z[5] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[40:48], e[0]) + binary.BigEndian.PutUint64((*b)[32:40], e[1]) + binary.BigEndian.PutUint64((*b)[24:32], e[2]) + binary.BigEndian.PutUint64((*b)[16:24], e[3]) + binary.BigEndian.PutUint64((*b)[8:16], e[4]) + binary.BigEndian.PutUint64((*b)[0:8], e[5]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1183,7 +1354,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1198,7 +1369,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(46) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1396,7 +1567,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1413,17 +1584,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1556,7 +1738,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[5], z[4] = madd2(m, q5, t[i+5], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-377/fp/element_mul_adx_amd64.s b/ecc/bls12-377/fp/element_mul_adx_amd64.s deleted file mode 100644 index e2afd074d..000000000 --- a/ecc/bls12-377/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,835 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET diff --git a/ecc/bls12-377/fp/element_mul_amd64.s b/ecc/bls12-377/fp/element_mul_amd64.s index b32bb9e20..3e7650e5a 100644 --- a/ecc/bls12-377/fp/element_mul_amd64.s +++ b/ecc/bls12-377/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fp/element_ops_amd64.go b/ecc/bls12-377/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.go +++ b/ecc/bls12-377/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-377/fp/element_ops_amd64.s b/ecc/bls12-377/fp/element_ops_amd64.s index 5c31cbc7a..7242622a4 100644 --- a/ecc/bls12-377/fp/element_ops_amd64.s +++ b/ecc/bls12-377/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-377/fp/element_ops_noasm.go b/ecc/bls12-377/fp/element_ops_noasm.go deleted file mode 100644 index 78c2dec58..000000000 --- a/ecc/bls12-377/fp/element_ops_noasm.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 1176283927673829444, - 14130787773971430395, - 11354866436980285261, - 15740727779991009548, - 14951814113394531041, - 33013799364667434, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls12-377/fp/element_ops_purego.go b/ecc/bls12-377/fp/element_ops_purego.go new file mode 100644 index 000000000..a4c3796b9 --- /dev/null +++ b/ecc/bls12-377/fp/element_ops_purego.go @@ -0,0 +1,745 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index 99478e367..d86eb4141 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -642,7 +635,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -757,7 +750,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -770,13 +763,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -805,17 +798,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -866,7 +859,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -879,13 +872,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -914,17 +907,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -975,7 +968,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -988,7 +981,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1002,7 +995,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1042,11 +1035,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1059,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1111,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1124,14 +1117,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1160,18 +1153,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1222,7 +1215,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,13 +1228,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1270,17 +1263,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1325,7 +1318,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1346,14 +1339,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1397,7 +1390,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1418,14 +1411,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1469,7 +1462,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1490,14 +1483,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1541,7 +1534,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1562,14 +1555,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1613,7 +1606,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1634,14 +1627,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2029,7 +2022,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2162,17 +2155,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2255,7 +2248,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2323,7 +2316,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2367,7 +2360,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord4, inversionCorrectionFactorWord5, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2409,7 +2402,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2558,7 +2551,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2771,11 +2764,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2812,7 +2805,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2830,7 +2823,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2841,7 +2834,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-377/fr/doc.go b/ecc/bls12-377/fr/doc.go index a03cc4d41..08f1a0ba2 100644 --- a/ecc/bls12-377/fr/doc.go +++ b/ecc/bls12-377/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 -// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 +// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 +// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 1b60b54a3..3a8fa5495 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 -// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 +// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 +// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 253 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 253 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 -// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 +// q[base10] = 8444461749428370424248824938781546531375899335154063827935233455917409239041 +// q[base16] = 0x12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 725501752471715839 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 9586122913090633729, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 81024008013859129, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(47) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-377/fr/element_mul_adx_amd64.s b/ecc/bls12-377/fr/element_mul_adx_amd64.s deleted file mode 100644 index 451ed7c65..000000000 --- a/ecc/bls12-377/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x0a11800000000001 -DATA q<>+8(SB)/8, $0x59aa76fed0000001 -DATA q<>+16(SB)/8, $0x60b44d1e5c37b001 -DATA q<>+24(SB)/8, $0x12ab655e9a2ca556 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a117fffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bls12-377/fr/element_mul_amd64.s b/ecc/bls12-377/fr/element_mul_amd64.s index 1293431f4..dc601e91e 100644 --- a/ecc/bls12-377/fr/element_mul_amd64.s +++ b/ecc/bls12-377/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-377/fr/element_ops_amd64.go b/ecc/bls12-377/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.go +++ b/ecc/bls12-377/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-377/fr/element_ops_amd64.s b/ecc/bls12-377/fr/element_ops_amd64.s index e80d6694d..afe75ff25 100644 --- a/ecc/bls12-377/fr/element_ops_amd64.s +++ b/ecc/bls12-377/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-377/fr/element_ops_noasm.go b/ecc/bls12-377/fr/element_ops_noasm.go deleted file mode 100644 index c3680f932..000000000 --- a/ecc/bls12-377/fr/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 18434640649710993230, - 12067750152132099910, - 14024878721438555919, - 347766975729306096, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls12-377/fr/element_ops_purego.go b/ecc/bls12-377/fr/element_ops_purego.go new file mode 100644 index 000000000..fe434ed61 --- /dev/null +++ b/ecc/bls12-377/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18434640649710993230, + 12067750152132099910, + 14024878721438555919, + 347766975729306096, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index d6bc801d7..cb00a5c62 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-377/fr/fri/fri.go b/ecc/bls12-377/fr/fri/fri.go index 02dc3c15a..491aa51aa 100644 --- a/ecc/bls12-377/fr/fri/fri.go +++ b/ecc/bls12-377/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bls12-377/fr/gkr/gkr.go b/ecc/bls12-377/fr/gkr/gkr.go new file mode 100644 index 000000000..324b80b21 --- /dev/null +++ b/ecc/bls12-377/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bls12-377/fr/gkr/gkr_test.go b/ecc/bls12-377/fr/gkr/gkr_test.go new file mode 100644 index 000000000..f402f1cad --- /dev/null +++ b/ecc/bls12-377/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bls12-377/fr/kzg/kzg.go b/ecc/bls12-377/fr/kzg/kzg.go index ea8ee346e..d1ad98ad7 100644 --- a/ecc/bls12-377/fr/kzg/kzg.go +++ b/ecc/bls12-377/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bls12377.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bls12377.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bls12-377/fr/mimc/decompose.go b/ecc/bls12-377/fr/mimc/decompose.go new file mode 100644 index 000000000..a51138765 --- /dev/null +++ b/ecc/bls12-377/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bls12-377/fr/mimc/decompose_test.go b/ecc/bls12-377/fr/mimc/decompose_test.go new file mode 100644 index 000000000..937192ced --- /dev/null +++ b/ecc/bls12-377/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bls12-377/fr/mimc/mimc.go b/ecc/bls12-377/fr/mimc/mimc.go index 2fd03ea0c..0ce9f4ec0 100644 --- a/ecc/bls12-377/fr/mimc/mimc.go +++ b/ecc/bls12-377/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bls12-377/fr/pedersen/pedersen.go b/ecc/bls12-377/fr/pedersen/pedersen.go new file mode 100644 index 000000000..18c130742 --- /dev/null +++ b/ecc/bls12-377/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bls12377.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bls12377.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bls12377.G1Affine + basisExpSigma []bls12377.G1Affine +} + +func randomOnG2() (bls12377.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12377.G2Affine{}, err + } + return bls12377.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bls12377.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bls12377.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bls12377.G1Affine, knowledgeProof bls12377.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bls12377.G1Affine, knowledgeProof bls12377.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bls12377.Pair([]bls12377.G1Affine{commitment, knowledgeProof}, []bls12377.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bls12-377/fr/pedersen/pedersen_test.go b/ecc/bls12-377/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..51a14eef4 --- /dev/null +++ b/ecc/bls12-377/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bls12377.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12377.G1Affine{}, err + } + return bls12377.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bls12377.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bls12377.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bls12-377/fr/plookup/vector.go b/ecc/bls12-377/fr/plookup/vector.go index 27e11697f..b8a4bf724 100644 --- a/ecc/bls12-377/fr/plookup/vector.go +++ b/ecc/bls12-377/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bls12-377/fr/polynomial/multilin.go b/ecc/bls12-377/fr/polynomial/multilin.go index f7ad79d4b..7fb2c2590 100644 --- a/ecc/bls12-377/fr/polynomial/multilin.go +++ b/ecc/bls12-377/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bls12-377/fr/polynomial/polynomial_test.go b/ecc/bls12-377/fr/polynomial/polynomial_test.go index f1d623be8..eab8d3797 100644 --- a/ecc/bls12-377/fr/polynomial/polynomial_test.go +++ b/ecc/bls12-377/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bls12-377/fr/polynomial/pool.go b/ecc/bls12-377/fr/polynomial/pool.go index 447059759..00180a8d8 100644 --- a/ecc/bls12-377/fr/polynomial/pool.go +++ b/ecc/bls12-377/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bls12-377/fr/sumcheck/sumcheck.go b/ecc/bls12-377/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..d4306b78c --- /dev/null +++ b/ecc/bls12-377/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bls12-377/fr/sumcheck/sumcheck_test.go b/ecc/bls12-377/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..dd741858e --- /dev/null +++ b/ecc/bls12-377/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go b/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..aaa5f41e1 --- /dev/null +++ b/ecc/bls12-377/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bls12-377/g1.go b/ecc/bls12-377/g1.go index 4a39ccf7c..3e44f69cf 100644 --- a/ecc/bls12-377/g1.go +++ b/ecc/bls12-377/g1.go @@ -17,13 +17,12 @@ package bls12377 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -381,6 +387,7 @@ func (p *G1Jac) IsOnCurve() bool { func (p *G1Jac) IsInSubGroup() bool { var res G1Jac + res.phi(p). ScalarMultiplication(&res, &xGen). ScalarMultiplication(&res, &xGen). @@ -472,8 +479,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -583,15 +590,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -603,11 +610,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -632,6 +635,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -874,95 +879,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -976,3 +958,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-377/g1_test.go b/ecc/bls12-377/g1_test.go index 3209de0cd..94f0c32de 100644 --- a/ecc/bls12-377/g1_test.go +++ b/ecc/bls12-377/g1_test.go @@ -19,6 +19,7 @@ package bls12377 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-377/g2.go b/ecc/bls12-377/g2.go index c80b5bdee..c2155c8dc 100644 --- a/ecc/bls12-377/g2.go +++ b/ecc/bls12-377/g2.go @@ -17,13 +17,12 @@ package bls12377 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" "github.com/consensys/gnark-crypto/ecc/bls12-377/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E2 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E2 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -370,7 +376,8 @@ func (p *G2Jac) IsOnCurve() bool { } // https://eprint.iacr.org/2021/1130.pdf, sec.4 -// ψ(p) = x₀ P +// and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 +// ψ(p) = [x₀]P func (p *G2Jac) IsInSubGroup() bool { var res, tmp G2Jac tmp.psi(p) @@ -471,8 +478,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -600,15 +607,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E2 + var A, B, U1, U2, S1, S2 fptower.E2 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -620,11 +627,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E2 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E2 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -649,6 +652,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E2 @@ -872,93 +877,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -971,3 +953,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E2 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E2 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-377/g2_test.go b/ecc/bls12-377/g2_test.go index d3b0af12b..e0ed36073 100644 --- a/ecc/bls12-377/g2_test.go +++ b/ecc/bls12-377/g2_test.go @@ -19,6 +19,7 @@ package bls12377 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-377/internal/fptower" @@ -339,7 +340,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -464,8 +465,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -478,7 +478,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -505,6 +505,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -517,8 +544,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-377/hash_to_g1.go b/ecc/bls12-377/hash_to_g1.go index d63191b36..6d5762e3a 100644 --- a/ecc/bls12-377/hash_to_g1.go +++ b/ecc/bls12-377/hash_to_g1.go @@ -17,7 +17,6 @@ package bls12377 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" "math/big" @@ -254,35 +253,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -300,11 +278,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -320,9 +298,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bls12-377/hash_to_g1_test.go b/ecc/bls12-377/hash_to_g1_test.go index 385ac04b5..00ec21465 100644 --- a/ecc/bls12-377/hash_to_g1_test.go +++ b/ecc/bls12-377/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bls12-377/hash_to_g2.go b/ecc/bls12-377/hash_to_g2.go index 46e6c62ae..03f649caa 100644 --- a/ecc/bls12-377/hash_to_g2.go +++ b/ecc/bls12-377/hash_to_g2.go @@ -715,8 +715,7 @@ func g2EvalPolynomial(z *fptower.E2, monic bool, coefficients []fptower.E2, x *f // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fptower.E2) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() sign := uint64(0) // 1. sign = 0 zero := uint64(1) // 2. zero = 1 @@ -750,11 +749,11 @@ func MapToG2(u fptower.E2) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SSWU map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 2) + u, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -773,9 +772,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SSWU map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*2) + u, err := fp.Hash(msg, dst, 2*2) if err != nil { return G2Affine{}, err } diff --git a/ecc/bls12-377/hash_to_g2_test.go b/ecc/bls12-377/hash_to_g2_test.go index dd56e76c5..f3b31b8a8 100644 --- a/ecc/bls12-377/hash_to_g2_test.go +++ b/ecc/bls12-377/hash_to_g2_test.go @@ -64,7 +64,7 @@ func TestG2SqrtRatio(t *testing.T) { func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 2) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 2) if err != nil { t.Error(err) } @@ -72,7 +72,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*2) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*2) if err != nil { t.Error(err) } @@ -222,7 +222,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE2Prime(p G2Affine) bool { var A, B fptower.E2 @@ -251,7 +251,7 @@ func isOnE2Prime(p G2Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fptower.E2, s string) { ssplit := strings.Split(s, ",") if len(ssplit) != 2 { diff --git a/ecc/bls12-377/internal/fptower/e12.go b/ecc/bls12-377/internal/fptower/e12.go index b3edf3a17..de9c52aa9 100644 --- a/ecc/bls12-377/internal/fptower/e12.go +++ b/ecc/bls12-377/internal/fptower/e12.go @@ -17,7 +17,6 @@ package fptower import ( - "encoding/binary" "errors" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" @@ -68,20 +67,6 @@ func (z *E12) SetOne() *E12 { return z } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - return z -} - // Add set z=x+y in E12 and return z func (z *E12) Add(x, y *E12) *E12 { z.C0.Add(&x.C0, &y.C0) @@ -119,6 +104,10 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() } +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + // Mul set z=x*y in E12 and return z func (z *E12) Mul(x, y *E12) *E12 { var a, b, c E6 @@ -226,9 +215,12 @@ func (z *E12) CyclotomicSquareCompressed(x *E12) *E12 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -289,9 +281,12 @@ func (z *E12) DecompressKarabina(x *E12) *E12 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -602,8 +597,8 @@ func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { @@ -652,93 +647,20 @@ func (z *E12) Unmarshal(buf []byte) error { // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) Bytes() (r [SizeOfGT]byte) { - _z := *z - _z.FromMont() - binary.BigEndian.PutUint64(r[568:576], _z.C0.B0.A0[0]) - binary.BigEndian.PutUint64(r[560:568], _z.C0.B0.A0[1]) - binary.BigEndian.PutUint64(r[552:560], _z.C0.B0.A0[2]) - binary.BigEndian.PutUint64(r[544:552], _z.C0.B0.A0[3]) - binary.BigEndian.PutUint64(r[536:544], _z.C0.B0.A0[4]) - binary.BigEndian.PutUint64(r[528:536], _z.C0.B0.A0[5]) - - binary.BigEndian.PutUint64(r[520:528], _z.C0.B0.A1[0]) - binary.BigEndian.PutUint64(r[512:520], _z.C0.B0.A1[1]) - binary.BigEndian.PutUint64(r[504:512], _z.C0.B0.A1[2]) - binary.BigEndian.PutUint64(r[496:504], _z.C0.B0.A1[3]) - binary.BigEndian.PutUint64(r[488:496], _z.C0.B0.A1[4]) - binary.BigEndian.PutUint64(r[480:488], _z.C0.B0.A1[5]) - - binary.BigEndian.PutUint64(r[472:480], _z.C0.B1.A0[0]) - binary.BigEndian.PutUint64(r[464:472], _z.C0.B1.A0[1]) - binary.BigEndian.PutUint64(r[456:464], _z.C0.B1.A0[2]) - binary.BigEndian.PutUint64(r[448:456], _z.C0.B1.A0[3]) - binary.BigEndian.PutUint64(r[440:448], _z.C0.B1.A0[4]) - binary.BigEndian.PutUint64(r[432:440], _z.C0.B1.A0[5]) - - binary.BigEndian.PutUint64(r[424:432], _z.C0.B1.A1[0]) - binary.BigEndian.PutUint64(r[416:424], _z.C0.B1.A1[1]) - binary.BigEndian.PutUint64(r[408:416], _z.C0.B1.A1[2]) - binary.BigEndian.PutUint64(r[400:408], _z.C0.B1.A1[3]) - binary.BigEndian.PutUint64(r[392:400], _z.C0.B1.A1[4]) - binary.BigEndian.PutUint64(r[384:392], _z.C0.B1.A1[5]) - - binary.BigEndian.PutUint64(r[376:384], _z.C0.B2.A0[0]) - binary.BigEndian.PutUint64(r[368:376], _z.C0.B2.A0[1]) - binary.BigEndian.PutUint64(r[360:368], _z.C0.B2.A0[2]) - binary.BigEndian.PutUint64(r[352:360], _z.C0.B2.A0[3]) - binary.BigEndian.PutUint64(r[344:352], _z.C0.B2.A0[4]) - binary.BigEndian.PutUint64(r[336:344], _z.C0.B2.A0[5]) - - binary.BigEndian.PutUint64(r[328:336], _z.C0.B2.A1[0]) - binary.BigEndian.PutUint64(r[320:328], _z.C0.B2.A1[1]) - binary.BigEndian.PutUint64(r[312:320], _z.C0.B2.A1[2]) - binary.BigEndian.PutUint64(r[304:312], _z.C0.B2.A1[3]) - binary.BigEndian.PutUint64(r[296:304], _z.C0.B2.A1[4]) - binary.BigEndian.PutUint64(r[288:296], _z.C0.B2.A1[5]) - - binary.BigEndian.PutUint64(r[280:288], _z.C1.B0.A0[0]) - binary.BigEndian.PutUint64(r[272:280], _z.C1.B0.A0[1]) - binary.BigEndian.PutUint64(r[264:272], _z.C1.B0.A0[2]) - binary.BigEndian.PutUint64(r[256:264], _z.C1.B0.A0[3]) - binary.BigEndian.PutUint64(r[248:256], _z.C1.B0.A0[4]) - binary.BigEndian.PutUint64(r[240:248], _z.C1.B0.A0[5]) - - binary.BigEndian.PutUint64(r[232:240], _z.C1.B0.A1[0]) - binary.BigEndian.PutUint64(r[224:232], _z.C1.B0.A1[1]) - binary.BigEndian.PutUint64(r[216:224], _z.C1.B0.A1[2]) - binary.BigEndian.PutUint64(r[208:216], _z.C1.B0.A1[3]) - binary.BigEndian.PutUint64(r[200:208], _z.C1.B0.A1[4]) - binary.BigEndian.PutUint64(r[192:200], _z.C1.B0.A1[5]) - - binary.BigEndian.PutUint64(r[184:192], _z.C1.B1.A0[0]) - binary.BigEndian.PutUint64(r[176:184], _z.C1.B1.A0[1]) - binary.BigEndian.PutUint64(r[168:176], _z.C1.B1.A0[2]) - binary.BigEndian.PutUint64(r[160:168], _z.C1.B1.A0[3]) - binary.BigEndian.PutUint64(r[152:160], _z.C1.B1.A0[4]) - binary.BigEndian.PutUint64(r[144:152], _z.C1.B1.A0[5]) - - binary.BigEndian.PutUint64(r[136:144], _z.C1.B1.A1[0]) - binary.BigEndian.PutUint64(r[128:136], _z.C1.B1.A1[1]) - binary.BigEndian.PutUint64(r[120:128], _z.C1.B1.A1[2]) - binary.BigEndian.PutUint64(r[112:120], _z.C1.B1.A1[3]) - binary.BigEndian.PutUint64(r[104:112], _z.C1.B1.A1[4]) - binary.BigEndian.PutUint64(r[96:104], _z.C1.B1.A1[5]) - - binary.BigEndian.PutUint64(r[88:96], _z.C1.B2.A0[0]) - binary.BigEndian.PutUint64(r[80:88], _z.C1.B2.A0[1]) - binary.BigEndian.PutUint64(r[72:80], _z.C1.B2.A0[2]) - binary.BigEndian.PutUint64(r[64:72], _z.C1.B2.A0[3]) - binary.BigEndian.PutUint64(r[56:64], _z.C1.B2.A0[4]) - binary.BigEndian.PutUint64(r[48:56], _z.C1.B2.A0[5]) - - binary.BigEndian.PutUint64(r[40:48], _z.C1.B2.A1[0]) - binary.BigEndian.PutUint64(r[32:40], _z.C1.B2.A1[1]) - binary.BigEndian.PutUint64(r[24:32], _z.C1.B2.A1[2]) - binary.BigEndian.PutUint64(r[16:24], _z.C1.B2.A1[3]) - binary.BigEndian.PutUint64(r[8:16], _z.C1.B2.A1[4]) - binary.BigEndian.PutUint64(r[0:8], _z.C1.B2.A1[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[528:528+fp.Bytes]), z.C0.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[480:480+fp.Bytes]), z.C0.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[432:432+fp.Bytes]), z.C0.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[384:384+fp.Bytes]), z.C0.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[336:336+fp.Bytes]), z.C0.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[288:288+fp.Bytes]), z.C0.B2.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[240:240+fp.Bytes]), z.C1.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[192:192+fp.Bytes]), z.C1.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[144:144+fp.Bytes]), z.C1.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[96:96+fp.Bytes]), z.C1.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[48:48+fp.Bytes]), z.C1.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[0:0+fp.Bytes]), z.C1.B2.A1) return } @@ -746,34 +668,47 @@ func (z *E12) Bytes() (r [SizeOfGT]byte) { // SetBytes interprets e as the bytes of a big-endian GT // sets z to that value (in Montgomery form), and returns z. // size(e) == 48 * 12 -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) SetBytes(e []byte) error { if len(e) != SizeOfGT { return errors.New("invalid buffer size") } - z.C0.B0.A0.SetBytes(e[528 : 528+fp.Bytes]) - - z.C0.B0.A1.SetBytes(e[480 : 480+fp.Bytes]) - - z.C0.B1.A0.SetBytes(e[432 : 432+fp.Bytes]) - - z.C0.B1.A1.SetBytes(e[384 : 384+fp.Bytes]) - - z.C0.B2.A0.SetBytes(e[336 : 336+fp.Bytes]) - - z.C0.B2.A1.SetBytes(e[288 : 288+fp.Bytes]) - - z.C1.B0.A0.SetBytes(e[240 : 240+fp.Bytes]) - - z.C1.B0.A1.SetBytes(e[192 : 192+fp.Bytes]) - - z.C1.B1.A0.SetBytes(e[144 : 144+fp.Bytes]) - - z.C1.B1.A1.SetBytes(e[96 : 96+fp.Bytes]) - - z.C1.B2.A0.SetBytes(e[48 : 48+fp.Bytes]) - - z.C1.B2.A1.SetBytes(e[0 : 0+fp.Bytes]) + if err := z.C0.B0.A0.SetBytesCanonical(e[528 : 528+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B0.A1.SetBytesCanonical(e[480 : 480+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A0.SetBytesCanonical(e[432 : 432+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A1.SetBytesCanonical(e[384 : 384+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A0.SetBytesCanonical(e[336 : 336+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A1.SetBytesCanonical(e[288 : 288+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A0.SetBytesCanonical(e[240 : 240+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A1.SetBytesCanonical(e[192 : 192+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A0.SetBytesCanonical(e[144 : 144+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A1.SetBytesCanonical(e[96 : 96+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A0.SetBytesCanonical(e[48 : 48+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A1.SetBytesCanonical(e[0 : 0+fp.Bytes]); err != nil { + return err + } return nil } diff --git a/ecc/bls12-377/internal/fptower/e2.go b/ecc/bls12-377/internal/fptower/e2.go index c617faf6e..640fd5209 100644 --- a/ecc/bls12-377/internal/fptower/e2.go +++ b/ecc/bls12-377/internal/fptower/e2.go @@ -31,12 +31,20 @@ func (z *E2) Equal(x *E2) bool { return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) } +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -98,6 +106,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -127,20 +139,6 @@ func (z *E2) String() string { return z.A0.String() + "+" + z.A1.String() + "*u" } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bls12-377/internal/fptower/e6.go b/ecc/bls12-377/internal/fptower/e6.go index 4da093f5f..8ae7216ec 100644 --- a/ecc/bls12-377/internal/fptower/e6.go +++ b/ecc/bls12-377/internal/fptower/e6.go @@ -63,25 +63,13 @@ func (z *E6) SetRandom() (*E6, error) { return z, nil } -// IsZero returns true if the two elements are equal, fasle otherwise +// IsZero returns true if the two elements are equal, false otherwise func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - z.B2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - z.B2.FromMont() - return z +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() } // Add adds two elements of E6 diff --git a/ecc/bls12-377/marshal.go b/ecc/bls12-377/marshal.go index b3fc67d79..b39f62428 100644 --- a/ecc/bls12-377/marshal.go +++ b/ecc/bls12-377/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,14 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -680,29 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -753,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -774,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -848,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -857,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -866,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -909,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -921,23 +911,8 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) res[0] |= msbMask @@ -956,49 +931,16 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.A1 | p.Y.A0 - tmp = p.Y.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - - tmp = p.Y.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[136:144], tmp[0]) - binary.BigEndian.PutUint64(res[128:136], tmp[1]) - binary.BigEndian.PutUint64(res[120:128], tmp[2]) - binary.BigEndian.PutUint64(res[112:120], tmp[3]) - binary.BigEndian.PutUint64(res[104:112], tmp[4]) - binary.BigEndian.PutUint64(res[96:104], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[144:144+fp.Bytes]), p.Y.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y.A1) // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) - - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) res[0] |= mUncompressed @@ -1050,11 +992,19 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(buf[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // p.Y.A1 | p.Y.A0 - p.Y.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.Y.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.Y.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.Y.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1075,8 +1025,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } var YSquared, Y fptower.E2 @@ -1152,7 +1106,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1161,7 +1115,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1171,12 +1125,16 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return false, err + } // store mData in p.Y.A0[0] p.Y.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bls12-377/multiexp.go b/ecc/bls12-377/multiexp.go index 6d4f14f13..f75f78aa0 100644 --- a/ecc/bls12-377/multiexp.go +++ b/ecc/bls12-377/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,177 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +282,445 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 2: + return processChunkG2Jacobian[bucketg2JacExtendedC2] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bls12-377/multiexp_affine.go b/ecc/bls12-377/multiexp_affine.go new file mode 100644 index 000000000..5aa3546b5 --- /dev/null +++ b/ecc/bls12-377/multiexp_affine.go @@ -0,0 +1,686 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12377 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-377/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E2 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E2 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E2 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E2 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E2 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E2 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E2 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E2 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC2 [2]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bls12-377/multiexp_jacobian.go b/ecc/bls12-377/multiexp_jacobian.go new file mode 100644 index 000000000..e01f5567f --- /dev/null +++ b/ecc/bls12-377/multiexp_jacobian.go @@ -0,0 +1,171 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12377 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC2 [2]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC2 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bls12-377/multiexp_test.go b/ecc/bls12-377/multiexp_test.go index 7f87124d9..6d11510f9 100644 --- a/ecc/bls12-377/multiexp_test.go +++ b/ecc/bls12-377/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bls12-377/twistededwards/eddsa/doc.go b/ecc/bls12-377/twistededwards/eddsa/doc.go index 568432e87..af7da2990 100644 --- a/ecc/bls12-377/twistededwards/eddsa/doc.go +++ b/ecc/bls12-377/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls12-377's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go b/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go index aaf4b94cf..1175d7458 100644 --- a/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls12-377/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls12-377/twistededwards/eddsa/marshal.go b/ecc/bls12-377/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls12-377/twistededwards/eddsa/marshal.go +++ b/ecc/bls12-377/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls12-377/twistededwards/point.go b/ecc/bls12-377/twistededwards/point.go index 808cfbee5..9ec9e0741 100644 --- a/ecc/bls12-377/twistededwards/point.go +++ b/ecc/bls12-377/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bls12-378/bls12-378.go b/ecc/bls12-378/bls12-378.go index 5b6cc91c3..e1f438b29 100644 --- a/ecc/bls12-378/bls12-378.go +++ b/ecc/bls12-378/bls12-378.go @@ -1,23 +1,29 @@ // Package bls12378 efficient elliptic curve, pairing and hash to curve implementation for bls12-378. // // bls12-378: A Barreto--Lynn--Scott curve -// embedding degree k=12 -// seed x₀=11045256207009841153 -// 𝔽r: r=14883435066912132899950318861128167269793560281114003360875131245101026639873 (x₀⁴-x₀²+1) -// 𝔽p: p=605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 ((x₀-1)² ⋅ r(x₀)/3+x₀) -// (E/𝔽p): Y²=X³+1 -// (Eₜ/𝔽p²): Y² = X³+u (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// +// embedding degree k=12 +// seed x₀=11045256207009841153 +// 𝔽r: r=14883435066912132899950318861128167269793560281114003360875131245101026639873 (x₀⁴-x₀²+1) +// 𝔽p: p=605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 ((x₀-1)² ⋅ r(x₀)/3+x₀) +// (E/𝔽p): Y²=X³+1 +// (Eₜ/𝔽p²): Y² = X³+u (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²+5 -// 𝔽p⁶[v] = 𝔽p²/v³-u -// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// 𝔽p²[u] = 𝔽p/u²+5 +// 𝔽p⁶[v] = 𝔽p²/v³-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// // optimal Ate loop size: -// x₀ +// +// x₀ +// // Security: estimated 126-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 254 bits and p¹² is 4536 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bls12378 diff --git a/ecc/bls12-378/fp/doc.go b/ecc/bls12-378/fp/doc.go index 70fbca511..3068596ed 100644 --- a/ecc/bls12-378/fp/doc.go +++ b/ecc/bls12-378/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [6]uint64 // -// Usage +// type Element [6]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bls12-378/fp/element.go b/ecc/bls12-378/fp/element.go index 4161fec82..c77191c88 100644 --- a/ecc/bls12-378/fp/element.go +++ b/ecc/bls12-378/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 6 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [6]uint64 const ( - Limbs = 6 // number of 64 bits words needed to represent a Element - Bits = 378 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 6 // number of 64 bits words needed to represent a Element + Bits = 378 // number of bits needed to represent a Element + Bytes = 48 // number of bytes needed to represent a Element ) // Field modulus q @@ -72,8 +72,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -82,12 +82,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 11045256207009841151 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001", 16) } @@ -95,8 +89,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -107,7 +102,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -139,14 +134,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -260,15 +256,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -280,15 +274,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[5] > _x[5] { return 1 } else if _z[5] < _x[5] { @@ -329,8 +320,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 5522628103504920577, 0) @@ -429,67 +419,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -505,7 +437,7 @@ func (z *Element) Add(x, y *Element) *Element { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -529,7 +461,7 @@ func (z *Element) Double(x *Element) *Element { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -592,115 +524,219 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [6]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - z[5], z[4] = madd3(m, q5, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [7]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + + if t[6] != 0 { + // we need to reduce, we have a result on 7 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], _ = bits.Sub64(t[5], q5, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -710,7 +746,6 @@ func _mulGeneric(z, x, y *Element) { z[4], b = bits.Sub64(z[4], q4, b) z[5], _ = bits.Sub64(z[5], q5, b) } - } func _fromMontGeneric(z *Element) { @@ -784,7 +819,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -798,7 +833,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -870,6 +905,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -884,8 +948,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -913,23 +977,31 @@ var rSquare = Element{ 51529254522778566, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -948,51 +1020,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[40:48], z[0]) - binary.BigEndian.PutUint64(b[32:40], z[1]) - binary.BigEndian.PutUint64(b[24:32], z[2]) - binary.BigEndian.PutUint64(b[16:24], z[3]) - binary.BigEndian.PutUint64(b[8:16], z[4]) - binary.BigEndian.PutUint64(b[0:8], z[5]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[40:48], _z[0]) - binary.BigEndian.PutUint64(res[32:40], _z[1]) - binary.BigEndian.PutUint64(res[24:32], _z[2]) - binary.BigEndian.PutUint64(res[16:24], _z[3]) - binary.BigEndian.PutUint64(res[8:16], _z[4]) - binary.BigEndian.PutUint64(res[0:8], _z[5]) +// Bits provides access to z by returning its value as a little-endian [6]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [6]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1005,19 +1075,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 48-byte integer. +// If e is not a 48-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1035,17 +1130,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1067,20 +1161,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1089,7 +1183,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1098,7 +1192,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1138,7 +1232,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1147,10 +1241,87 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 48-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[40:48]) + z[1] = binary.BigEndian.Uint64((*b)[32:40]) + z[2] = binary.BigEndian.Uint64((*b)[24:32]) + z[3] = binary.BigEndian.Uint64((*b)[16:24]) + z[4] = binary.BigEndian.Uint64((*b)[8:16]) + z[5] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[40:48], e[0]) + binary.BigEndian.PutUint64((*b)[32:40], e[1]) + binary.BigEndian.PutUint64((*b)[24:32], e[2]) + binary.BigEndian.PutUint64((*b)[16:24], e[3]) + binary.BigEndian.PutUint64((*b)[8:16], e[4]) + binary.BigEndian.PutUint64((*b)[0:8], e[5]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1183,7 +1354,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1198,7 +1369,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(41) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1396,7 +1567,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1413,17 +1584,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1556,7 +1738,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[5], z[4] = madd2(m, q5, t[i+5], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-378/fp/element_mul_adx_amd64.s b/ecc/bls12-378/fp/element_mul_adx_amd64.s deleted file mode 100644 index 58909b6ec..000000000 --- a/ecc/bls12-378/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,835 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x9948a20000000001 -DATA q<>+8(SB)/8, $0xce97f76a822c0000 -DATA q<>+16(SB)/8, $0x980dc360d0a49d7f -DATA q<>+24(SB)/8, $0x84059eb647102326 -DATA q<>+32(SB)/8, $0x53cb5d240ed107a2 -DATA q<>+40(SB)/8, $0x03eeb0416684d190 -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x9948a1ffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET diff --git a/ecc/bls12-378/fp/element_mul_amd64.s b/ecc/bls12-378/fp/element_mul_amd64.s index 3afd58112..39ededda7 100644 --- a/ecc/bls12-378/fp/element_mul_amd64.s +++ b/ecc/bls12-378/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-378/fp/element_ops_amd64.go b/ecc/bls12-378/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bls12-378/fp/element_ops_amd64.go +++ b/ecc/bls12-378/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-378/fp/element_ops_amd64.s b/ecc/bls12-378/fp/element_ops_amd64.s index fa881ff9c..9440e0ccb 100644 --- a/ecc/bls12-378/fp/element_ops_amd64.s +++ b/ecc/bls12-378/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-378/fp/element_ops_noasm.go b/ecc/bls12-378/fp/element_ops_noasm.go deleted file mode 100644 index 407f4f93e..000000000 --- a/ecc/bls12-378/fp/element_ops_noasm.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 8212494240417053874, - 5029498262967025157, - 9404736542133420963, - 13073247822498485877, - 1581382318314538223, - 87125160541517067, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls12-378/fp/element_ops_purego.go b/ecc/bls12-378/fp/element_ops_purego.go new file mode 100644 index 000000000..ac0e48737 --- /dev/null +++ b/ecc/bls12-378/fp/element_ops_purego.go @@ -0,0 +1,745 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 8212494240417053874, + 5029498262967025157, + 9404736542133420963, + 13073247822498485877, + 1581382318314538223, + 87125160541517067, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} diff --git a/ecc/bls12-378/fp/element_test.go b/ecc/bls12-378/fp/element_test.go index 2b9ddf149..a3ef820b6 100644 --- a/ecc/bls12-378/fp/element_test.go +++ b/ecc/bls12-378/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -642,7 +635,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -757,7 +750,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -770,13 +763,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -805,17 +798,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -866,7 +859,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -879,13 +872,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -914,17 +907,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -975,7 +968,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -988,7 +981,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1002,7 +995,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1042,11 +1035,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1059,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1111,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1124,14 +1117,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1160,18 +1153,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1222,7 +1215,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,13 +1228,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1270,17 +1263,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1325,7 +1318,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1346,14 +1339,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1397,7 +1390,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1418,14 +1411,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1469,7 +1462,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1490,14 +1483,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1541,7 +1534,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1562,14 +1555,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1613,7 +1606,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1634,14 +1627,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2029,7 +2022,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2162,17 +2155,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2255,7 +2248,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2323,7 +2316,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2367,7 +2360,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord4, inversionCorrectionFactorWord5, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2409,7 +2402,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2558,7 +2551,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2771,11 +2764,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2812,7 +2805,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2830,7 +2823,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2841,7 +2834,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-378/fr/doc.go b/ecc/bls12-378/fr/doc.go index d0ebdbfae..5c0faae5c 100644 --- a/ecc/bls12-378/fr/doc.go +++ b/ecc/bls12-378/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 -// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 +// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 +// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bls12-378/fr/element.go b/ecc/bls12-378/fr/element.go index 0bd32b5a7..d487cec4e 100644 --- a/ecc/bls12-378/fr/element.go +++ b/ecc/bls12-378/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 -// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 +// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 +// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 254 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 254 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 -// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 +// q[base10] = 14883435066912132899950318861128167269793560281114003360875131245101026639873 +// q[base16] = 0x20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 3643768340310130687 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("20e7b9c8ef7b2eb187787fb4e3dbb0ffeae77f3da09400013291440000000001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 11045256207009841153, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 405261321576397495, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(42) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-378/fr/element_mul_adx_amd64.s b/ecc/bls12-378/fr/element_mul_adx_amd64.s deleted file mode 100644 index 1d430bb16..000000000 --- a/ecc/bls12-378/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3291440000000001 -DATA q<>+8(SB)/8, $0xeae77f3da0940001 -DATA q<>+16(SB)/8, $0x87787fb4e3dbb0ff -DATA q<>+24(SB)/8, $0x20e7b9c8ef7b2eb1 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x329143ffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bls12-378/fr/element_mul_amd64.s b/ecc/bls12-378/fr/element_mul_amd64.s index 2b93dba2c..75908e04d 100644 --- a/ecc/bls12-378/fr/element_mul_amd64.s +++ b/ecc/bls12-378/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-378/fr/element_ops_amd64.go b/ecc/bls12-378/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bls12-378/fr/element_ops_amd64.go +++ b/ecc/bls12-378/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-378/fr/element_ops_amd64.s b/ecc/bls12-378/fr/element_ops_amd64.s index a4182bdb0..a863389b1 100644 --- a/ecc/bls12-378/fr/element_ops_amd64.s +++ b/ecc/bls12-378/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-378/fr/element_ops_noasm.go b/ecc/bls12-378/fr/element_ops_noasm.go deleted file mode 100644 index fc16770cd..000000000 --- a/ecc/bls12-378/fr/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 914279102867832731, - 5956798511920709511, - 10193226651174906632, - 329804807099814901, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls12-378/fr/element_ops_purego.go b/ecc/bls12-378/fr/element_ops_purego.go new file mode 100644 index 000000000..14036d3bd --- /dev/null +++ b/ecc/bls12-378/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 914279102867832731, + 5956798511920709511, + 10193226651174906632, + 329804807099814901, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bls12-378/fr/element_test.go b/ecc/bls12-378/fr/element_test.go index 765de28d3..1bd5c254f 100644 --- a/ecc/bls12-378/fr/element_test.go +++ b/ecc/bls12-378/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-378/fr/fri/fri.go b/ecc/bls12-378/fr/fri/fri.go index 6985df67f..837b08862 100644 --- a/ecc/bls12-378/fr/fri/fri.go +++ b/ecc/bls12-378/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bls12-378/fr/gkr/gkr.go b/ecc/bls12-378/fr/gkr/gkr.go new file mode 100644 index 000000000..c01ccec9c --- /dev/null +++ b/ecc/bls12-378/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bls12-378/fr/gkr/gkr_test.go b/ecc/bls12-378/fr/gkr/gkr_test.go new file mode 100644 index 000000000..72dd8f49a --- /dev/null +++ b/ecc/bls12-378/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bls12-378/fr/kzg/kzg.go b/ecc/bls12-378/fr/kzg/kzg.go index 638559099..4da124de6 100644 --- a/ecc/bls12-378/fr/kzg/kzg.go +++ b/ecc/bls12-378/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bls12378.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bls12378.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bls12-378/fr/mimc/decompose.go b/ecc/bls12-378/fr/mimc/decompose.go new file mode 100644 index 000000000..50a124f54 --- /dev/null +++ b/ecc/bls12-378/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bls12-378/fr/mimc/decompose_test.go b/ecc/bls12-378/fr/mimc/decompose_test.go new file mode 100644 index 000000000..a63b57ff5 --- /dev/null +++ b/ecc/bls12-378/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bls12-378/fr/mimc/mimc.go b/ecc/bls12-378/fr/mimc/mimc.go index 7d788e8f4..e20b5ee79 100644 --- a/ecc/bls12-378/fr/mimc/mimc.go +++ b/ecc/bls12-378/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bls12-378/fr/pedersen/pedersen.go b/ecc/bls12-378/fr/pedersen/pedersen.go new file mode 100644 index 000000000..3508d51d0 --- /dev/null +++ b/ecc/bls12-378/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-378" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bls12378.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bls12378.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bls12378.G1Affine + basisExpSigma []bls12378.G1Affine +} + +func randomOnG2() (bls12378.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12378.G2Affine{}, err + } + return bls12378.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bls12378.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bls12378.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bls12378.G1Affine, knowledgeProof bls12378.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bls12378.G1Affine, knowledgeProof bls12378.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bls12378.Pair([]bls12378.G1Affine{commitment, knowledgeProof}, []bls12378.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bls12-378/fr/pedersen/pedersen_test.go b/ecc/bls12-378/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..745d06334 --- /dev/null +++ b/ecc/bls12-378/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-378" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bls12378.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12378.G1Affine{}, err + } + return bls12378.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bls12378.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bls12378.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bls12-378/fr/plookup/vector.go b/ecc/bls12-378/fr/plookup/vector.go index b6358ab70..1930e244a 100644 --- a/ecc/bls12-378/fr/plookup/vector.go +++ b/ecc/bls12-378/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bls12-378/fr/polynomial/multilin.go b/ecc/bls12-378/fr/polynomial/multilin.go index ddfa19833..c419dfb99 100644 --- a/ecc/bls12-378/fr/polynomial/multilin.go +++ b/ecc/bls12-378/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bls12-378/fr/polynomial/polynomial_test.go b/ecc/bls12-378/fr/polynomial/polynomial_test.go index 73994acd5..7cf5ee479 100644 --- a/ecc/bls12-378/fr/polynomial/polynomial_test.go +++ b/ecc/bls12-378/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bls12-378/fr/polynomial/pool.go b/ecc/bls12-378/fr/polynomial/pool.go index 15c16d53f..e00970f67 100644 --- a/ecc/bls12-378/fr/polynomial/pool.go +++ b/ecc/bls12-378/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bls12-378/fr/sumcheck/sumcheck.go b/ecc/bls12-378/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..5638cda10 --- /dev/null +++ b/ecc/bls12-378/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bls12-378/fr/sumcheck/sumcheck_test.go b/ecc/bls12-378/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..6867729eb --- /dev/null +++ b/ecc/bls12-378/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bls12-378/fr/test_vector_utils/test_vector_utils.go b/ecc/bls12-378/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..de8684506 --- /dev/null +++ b/ecc/bls12-378/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-378/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bls12-378/g1.go b/ecc/bls12-378/g1.go index 6384c93df..cd9155b76 100644 --- a/ecc/bls12-378/g1.go +++ b/ecc/bls12-378/g1.go @@ -17,13 +17,12 @@ package bls12378 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-378/fp" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -381,6 +387,7 @@ func (p *G1Jac) IsOnCurve() bool { func (p *G1Jac) IsInSubGroup() bool { var res G1Jac + res.phi(p). ScalarMultiplication(&res, &xGen). ScalarMultiplication(&res, &xGen). @@ -472,8 +479,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -583,15 +590,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -603,11 +610,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -632,6 +635,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -874,95 +879,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -976,3 +958,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-378/g1_test.go b/ecc/bls12-378/g1_test.go index dccaca15c..93f106197 100644 --- a/ecc/bls12-378/g1_test.go +++ b/ecc/bls12-378/g1_test.go @@ -19,6 +19,7 @@ package bls12378 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-378/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-378/g2.go b/ecc/bls12-378/g2.go index 143474f59..b67b46779 100644 --- a/ecc/bls12-378/g2.go +++ b/ecc/bls12-378/g2.go @@ -17,13 +17,12 @@ package bls12378 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" "github.com/consensys/gnark-crypto/ecc/bls12-378/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E2 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E2 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -370,7 +376,8 @@ func (p *G2Jac) IsOnCurve() bool { } // https://eprint.iacr.org/2021/1130.pdf, sec.4 -// ψ(p) = x₀ P +// and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 +// ψ(p) = [x₀]P func (p *G2Jac) IsInSubGroup() bool { var res, tmp G2Jac tmp.psi(p) @@ -471,8 +478,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -600,15 +607,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E2 + var A, B, U1, U2, S1, S2 fptower.E2 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -620,11 +627,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E2 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E2 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -649,6 +652,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E2 @@ -872,93 +877,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -971,3 +953,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E2 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E2 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-378/g2_test.go b/ecc/bls12-378/g2_test.go index 21e79d723..e2eaa4454 100644 --- a/ecc/bls12-378/g2_test.go +++ b/ecc/bls12-378/g2_test.go @@ -19,6 +19,7 @@ package bls12378 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-378/internal/fptower" @@ -339,7 +340,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -464,8 +465,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -478,7 +478,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -505,6 +505,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -517,8 +544,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-378/hash_to_g1.go b/ecc/bls12-378/hash_to_g1.go index e67e9b40c..46986092d 100644 --- a/ecc/bls12-378/hash_to_g1.go +++ b/ecc/bls12-378/hash_to_g1.go @@ -17,7 +17,6 @@ package bls12378 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-378/fp" "math/big" @@ -256,35 +255,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -302,11 +280,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -322,9 +300,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bls12-378/hash_to_g1_test.go b/ecc/bls12-378/hash_to_g1_test.go index 69019c4fe..2964a0869 100644 --- a/ecc/bls12-378/hash_to_g1_test.go +++ b/ecc/bls12-378/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bls12-378/hash_to_g2.go b/ecc/bls12-378/hash_to_g2.go index 354fcc14d..d20ff1a77 100644 --- a/ecc/bls12-378/hash_to_g2.go +++ b/ecc/bls12-378/hash_to_g2.go @@ -100,7 +100,7 @@ func MapToG2(t fptower.E2) G2Affine { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-2.2.2 func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - _t, err := hashToFp(msg, dst, 2) + _t, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -115,7 +115,7 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-3 func HashToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 4) + u, err := fp.Hash(msg, dst, 4) if err != nil { return res, err } diff --git a/ecc/bls12-378/internal/fptower/e12.go b/ecc/bls12-378/internal/fptower/e12.go index f3dc8789f..39adbff08 100644 --- a/ecc/bls12-378/internal/fptower/e12.go +++ b/ecc/bls12-378/internal/fptower/e12.go @@ -17,7 +17,6 @@ package fptower import ( - "encoding/binary" "errors" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-378/fp" @@ -68,20 +67,6 @@ func (z *E12) SetOne() *E12 { return z } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - return z -} - // Add set z=x+y in E12 and return z func (z *E12) Add(x, y *E12) *E12 { z.C0.Add(&x.C0, &y.C0) @@ -119,6 +104,10 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() } +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + // Mul set z=x*y in E12 and return z func (z *E12) Mul(x, y *E12) *E12 { var a, b, c E6 @@ -226,9 +215,12 @@ func (z *E12) CyclotomicSquareCompressed(x *E12) *E12 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -289,9 +281,12 @@ func (z *E12) DecompressKarabina(x *E12) *E12 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -602,8 +597,8 @@ func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { @@ -652,93 +647,20 @@ func (z *E12) Unmarshal(buf []byte) error { // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) Bytes() (r [SizeOfGT]byte) { - _z := *z - _z.FromMont() - binary.BigEndian.PutUint64(r[568:576], _z.C0.B0.A0[0]) - binary.BigEndian.PutUint64(r[560:568], _z.C0.B0.A0[1]) - binary.BigEndian.PutUint64(r[552:560], _z.C0.B0.A0[2]) - binary.BigEndian.PutUint64(r[544:552], _z.C0.B0.A0[3]) - binary.BigEndian.PutUint64(r[536:544], _z.C0.B0.A0[4]) - binary.BigEndian.PutUint64(r[528:536], _z.C0.B0.A0[5]) - - binary.BigEndian.PutUint64(r[520:528], _z.C0.B0.A1[0]) - binary.BigEndian.PutUint64(r[512:520], _z.C0.B0.A1[1]) - binary.BigEndian.PutUint64(r[504:512], _z.C0.B0.A1[2]) - binary.BigEndian.PutUint64(r[496:504], _z.C0.B0.A1[3]) - binary.BigEndian.PutUint64(r[488:496], _z.C0.B0.A1[4]) - binary.BigEndian.PutUint64(r[480:488], _z.C0.B0.A1[5]) - - binary.BigEndian.PutUint64(r[472:480], _z.C0.B1.A0[0]) - binary.BigEndian.PutUint64(r[464:472], _z.C0.B1.A0[1]) - binary.BigEndian.PutUint64(r[456:464], _z.C0.B1.A0[2]) - binary.BigEndian.PutUint64(r[448:456], _z.C0.B1.A0[3]) - binary.BigEndian.PutUint64(r[440:448], _z.C0.B1.A0[4]) - binary.BigEndian.PutUint64(r[432:440], _z.C0.B1.A0[5]) - - binary.BigEndian.PutUint64(r[424:432], _z.C0.B1.A1[0]) - binary.BigEndian.PutUint64(r[416:424], _z.C0.B1.A1[1]) - binary.BigEndian.PutUint64(r[408:416], _z.C0.B1.A1[2]) - binary.BigEndian.PutUint64(r[400:408], _z.C0.B1.A1[3]) - binary.BigEndian.PutUint64(r[392:400], _z.C0.B1.A1[4]) - binary.BigEndian.PutUint64(r[384:392], _z.C0.B1.A1[5]) - - binary.BigEndian.PutUint64(r[376:384], _z.C0.B2.A0[0]) - binary.BigEndian.PutUint64(r[368:376], _z.C0.B2.A0[1]) - binary.BigEndian.PutUint64(r[360:368], _z.C0.B2.A0[2]) - binary.BigEndian.PutUint64(r[352:360], _z.C0.B2.A0[3]) - binary.BigEndian.PutUint64(r[344:352], _z.C0.B2.A0[4]) - binary.BigEndian.PutUint64(r[336:344], _z.C0.B2.A0[5]) - - binary.BigEndian.PutUint64(r[328:336], _z.C0.B2.A1[0]) - binary.BigEndian.PutUint64(r[320:328], _z.C0.B2.A1[1]) - binary.BigEndian.PutUint64(r[312:320], _z.C0.B2.A1[2]) - binary.BigEndian.PutUint64(r[304:312], _z.C0.B2.A1[3]) - binary.BigEndian.PutUint64(r[296:304], _z.C0.B2.A1[4]) - binary.BigEndian.PutUint64(r[288:296], _z.C0.B2.A1[5]) - - binary.BigEndian.PutUint64(r[280:288], _z.C1.B0.A0[0]) - binary.BigEndian.PutUint64(r[272:280], _z.C1.B0.A0[1]) - binary.BigEndian.PutUint64(r[264:272], _z.C1.B0.A0[2]) - binary.BigEndian.PutUint64(r[256:264], _z.C1.B0.A0[3]) - binary.BigEndian.PutUint64(r[248:256], _z.C1.B0.A0[4]) - binary.BigEndian.PutUint64(r[240:248], _z.C1.B0.A0[5]) - - binary.BigEndian.PutUint64(r[232:240], _z.C1.B0.A1[0]) - binary.BigEndian.PutUint64(r[224:232], _z.C1.B0.A1[1]) - binary.BigEndian.PutUint64(r[216:224], _z.C1.B0.A1[2]) - binary.BigEndian.PutUint64(r[208:216], _z.C1.B0.A1[3]) - binary.BigEndian.PutUint64(r[200:208], _z.C1.B0.A1[4]) - binary.BigEndian.PutUint64(r[192:200], _z.C1.B0.A1[5]) - - binary.BigEndian.PutUint64(r[184:192], _z.C1.B1.A0[0]) - binary.BigEndian.PutUint64(r[176:184], _z.C1.B1.A0[1]) - binary.BigEndian.PutUint64(r[168:176], _z.C1.B1.A0[2]) - binary.BigEndian.PutUint64(r[160:168], _z.C1.B1.A0[3]) - binary.BigEndian.PutUint64(r[152:160], _z.C1.B1.A0[4]) - binary.BigEndian.PutUint64(r[144:152], _z.C1.B1.A0[5]) - - binary.BigEndian.PutUint64(r[136:144], _z.C1.B1.A1[0]) - binary.BigEndian.PutUint64(r[128:136], _z.C1.B1.A1[1]) - binary.BigEndian.PutUint64(r[120:128], _z.C1.B1.A1[2]) - binary.BigEndian.PutUint64(r[112:120], _z.C1.B1.A1[3]) - binary.BigEndian.PutUint64(r[104:112], _z.C1.B1.A1[4]) - binary.BigEndian.PutUint64(r[96:104], _z.C1.B1.A1[5]) - - binary.BigEndian.PutUint64(r[88:96], _z.C1.B2.A0[0]) - binary.BigEndian.PutUint64(r[80:88], _z.C1.B2.A0[1]) - binary.BigEndian.PutUint64(r[72:80], _z.C1.B2.A0[2]) - binary.BigEndian.PutUint64(r[64:72], _z.C1.B2.A0[3]) - binary.BigEndian.PutUint64(r[56:64], _z.C1.B2.A0[4]) - binary.BigEndian.PutUint64(r[48:56], _z.C1.B2.A0[5]) - - binary.BigEndian.PutUint64(r[40:48], _z.C1.B2.A1[0]) - binary.BigEndian.PutUint64(r[32:40], _z.C1.B2.A1[1]) - binary.BigEndian.PutUint64(r[24:32], _z.C1.B2.A1[2]) - binary.BigEndian.PutUint64(r[16:24], _z.C1.B2.A1[3]) - binary.BigEndian.PutUint64(r[8:16], _z.C1.B2.A1[4]) - binary.BigEndian.PutUint64(r[0:8], _z.C1.B2.A1[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[528:528+fp.Bytes]), z.C0.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[480:480+fp.Bytes]), z.C0.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[432:432+fp.Bytes]), z.C0.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[384:384+fp.Bytes]), z.C0.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[336:336+fp.Bytes]), z.C0.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[288:288+fp.Bytes]), z.C0.B2.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[240:240+fp.Bytes]), z.C1.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[192:192+fp.Bytes]), z.C1.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[144:144+fp.Bytes]), z.C1.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[96:96+fp.Bytes]), z.C1.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[48:48+fp.Bytes]), z.C1.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[0:0+fp.Bytes]), z.C1.B2.A1) return } @@ -746,34 +668,47 @@ func (z *E12) Bytes() (r [SizeOfGT]byte) { // SetBytes interprets e as the bytes of a big-endian GT // sets z to that value (in Montgomery form), and returns z. // size(e) == 48 * 12 -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) SetBytes(e []byte) error { if len(e) != SizeOfGT { return errors.New("invalid buffer size") } - z.C0.B0.A0.SetBytes(e[528 : 528+fp.Bytes]) - - z.C0.B0.A1.SetBytes(e[480 : 480+fp.Bytes]) - - z.C0.B1.A0.SetBytes(e[432 : 432+fp.Bytes]) - - z.C0.B1.A1.SetBytes(e[384 : 384+fp.Bytes]) - - z.C0.B2.A0.SetBytes(e[336 : 336+fp.Bytes]) - - z.C0.B2.A1.SetBytes(e[288 : 288+fp.Bytes]) - - z.C1.B0.A0.SetBytes(e[240 : 240+fp.Bytes]) - - z.C1.B0.A1.SetBytes(e[192 : 192+fp.Bytes]) - - z.C1.B1.A0.SetBytes(e[144 : 144+fp.Bytes]) - - z.C1.B1.A1.SetBytes(e[96 : 96+fp.Bytes]) - - z.C1.B2.A0.SetBytes(e[48 : 48+fp.Bytes]) - - z.C1.B2.A1.SetBytes(e[0 : 0+fp.Bytes]) + if err := z.C0.B0.A0.SetBytesCanonical(e[528 : 528+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B0.A1.SetBytesCanonical(e[480 : 480+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A0.SetBytesCanonical(e[432 : 432+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A1.SetBytesCanonical(e[384 : 384+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A0.SetBytesCanonical(e[336 : 336+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A1.SetBytesCanonical(e[288 : 288+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A0.SetBytesCanonical(e[240 : 240+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A1.SetBytesCanonical(e[192 : 192+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A0.SetBytesCanonical(e[144 : 144+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A1.SetBytesCanonical(e[96 : 96+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A0.SetBytesCanonical(e[48 : 48+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A1.SetBytesCanonical(e[0 : 0+fp.Bytes]); err != nil { + return err + } return nil } diff --git a/ecc/bls12-378/internal/fptower/e2.go b/ecc/bls12-378/internal/fptower/e2.go index 4de9ceea9..0279a9d66 100644 --- a/ecc/bls12-378/internal/fptower/e2.go +++ b/ecc/bls12-378/internal/fptower/e2.go @@ -31,12 +31,20 @@ func (z *E2) Equal(x *E2) bool { return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) } +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -98,6 +106,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -127,20 +139,6 @@ func (z *E2) String() string { return z.A0.String() + "+" + z.A1.String() + "*u" } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bls12-378/internal/fptower/e6.go b/ecc/bls12-378/internal/fptower/e6.go index 4da093f5f..8ae7216ec 100644 --- a/ecc/bls12-378/internal/fptower/e6.go +++ b/ecc/bls12-378/internal/fptower/e6.go @@ -63,25 +63,13 @@ func (z *E6) SetRandom() (*E6, error) { return z, nil } -// IsZero returns true if the two elements are equal, fasle otherwise +// IsZero returns true if the two elements are equal, false otherwise func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - z.B2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - z.B2.FromMont() - return z +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() } // Add adds two elements of E6 diff --git a/ecc/bls12-378/marshal.go b/ecc/bls12-378/marshal.go index 1ac5c64e2..1120fa850 100644 --- a/ecc/bls12-378/marshal.go +++ b/ecc/bls12-378/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,14 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -680,29 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -753,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -774,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -848,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -857,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -866,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -909,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -921,23 +911,8 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) res[0] |= msbMask @@ -956,49 +931,16 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.A1 | p.Y.A0 - tmp = p.Y.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - - tmp = p.Y.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[136:144], tmp[0]) - binary.BigEndian.PutUint64(res[128:136], tmp[1]) - binary.BigEndian.PutUint64(res[120:128], tmp[2]) - binary.BigEndian.PutUint64(res[112:120], tmp[3]) - binary.BigEndian.PutUint64(res[104:112], tmp[4]) - binary.BigEndian.PutUint64(res[96:104], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[144:144+fp.Bytes]), p.Y.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y.A1) // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) - - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) res[0] |= mUncompressed @@ -1050,11 +992,19 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(buf[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // p.Y.A1 | p.Y.A0 - p.Y.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.Y.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.Y.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.Y.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1075,8 +1025,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } var YSquared, Y fptower.E2 @@ -1152,7 +1106,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1161,7 +1115,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1171,12 +1125,16 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return false, err + } // store mData in p.Y.A0[0] p.Y.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bls12-378/multiexp.go b/ecc/bls12-378/multiexp.go index b42536f1b..48e136e41 100644 --- a/ecc/bls12-378/multiexp.go +++ b/ecc/bls12-378/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,179 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +284,447 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 2: + return processChunkG2Jacobian[bucketg2JacExtendedC2] + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bls12-378/multiexp_affine.go b/ecc/bls12-378/multiexp_affine.go new file mode 100644 index 000000000..28301102c --- /dev/null +++ b/ecc/bls12-378/multiexp_affine.go @@ -0,0 +1,688 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12378 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-378/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-378/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E2 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E2 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E2 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E2 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E2 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E2 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E2 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E2 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC2 [2]bool +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bls12-378/multiexp_jacobian.go b/ecc/bls12-378/multiexp_jacobian.go new file mode 100644 index 000000000..0e9d572e7 --- /dev/null +++ b/ecc/bls12-378/multiexp_jacobian.go @@ -0,0 +1,175 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12378 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC2 [2]g2JacExtended +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC2 | + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bls12-378/multiexp_test.go b/ecc/bls12-378/multiexp_test.go index 39c7f5dd6..7e8cdd73a 100644 --- a/ecc/bls12-378/multiexp_test.go +++ b/ecc/bls12-378/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-378/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bls12-378/twistededwards/eddsa/doc.go b/ecc/bls12-378/twistededwards/eddsa/doc.go index e19c483f7..af8a37a6c 100644 --- a/ecc/bls12-378/twistededwards/eddsa/doc.go +++ b/ecc/bls12-378/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls12-378's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls12-378/twistededwards/eddsa/eddsa_test.go b/ecc/bls12-378/twistededwards/eddsa/eddsa_test.go index 5cfc0927f..f7e095669 100644 --- a/ecc/bls12-378/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls12-378/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls12-378/twistededwards/eddsa/marshal.go b/ecc/bls12-378/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls12-378/twistededwards/eddsa/marshal.go +++ b/ecc/bls12-378/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls12-378/twistededwards/point.go b/ecc/bls12-378/twistededwards/point.go index f7846fb6d..7db08b4cb 100644 --- a/ecc/bls12-378/twistededwards/point.go +++ b/ecc/bls12-378/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bls12-381/bandersnatch/eddsa/doc.go b/ecc/bls12-381/bandersnatch/eddsa/doc.go index fb4712633..af3fe1f93 100644 --- a/ecc/bls12-381/bandersnatch/eddsa/doc.go +++ b/ecc/bls12-381/bandersnatch/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls12-381's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go b/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go index 967aac1db..85bd27ca0 100644 --- a/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go +++ b/ecc/bls12-381/bandersnatch/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls12-381/bandersnatch/eddsa/marshal.go b/ecc/bls12-381/bandersnatch/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls12-381/bandersnatch/eddsa/marshal.go +++ b/ecc/bls12-381/bandersnatch/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls12-381/bandersnatch/endomorpism.go b/ecc/bls12-381/bandersnatch/endomorpism.go index 5c6aa7a0b..d68fdfc4e 100644 --- a/ecc/bls12-381/bandersnatch/endomorpism.go +++ b/ecc/bls12-381/bandersnatch/endomorpism.go @@ -76,8 +76,8 @@ func (p *PointProj) scalarMulGLV(p1 *PointProj, scalar *big.Int) *PointProj { table[14].Set(&table[11]).Add(&table[14], &table[2]) // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // loop starts from len(k1)/2 due to the bounds // fr.Limbs == Order.limbs @@ -166,8 +166,8 @@ func (p *PointExtended) scalarMulGLV(p1 *PointExtended, scalar *big.Int) *PointE table[14].Set(&table[11]).Add(&table[14], &table[2]) // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 bits long max - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // loop starts from len(k1)/2 due to the bounds // fr.Limbs == Order.limbs diff --git a/ecc/bls12-381/bandersnatch/point.go b/ecc/bls12-381/bandersnatch/point.go index 4e030bc1f..3c2db7816 100644 --- a/ecc/bls12-381/bandersnatch/point.go +++ b/ecc/bls12-381/bandersnatch/point.go @@ -48,7 +48,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bls12-381/bls12-381.go b/ecc/bls12-381/bls12-381.go index dbfb11edc..796369ed6 100644 --- a/ecc/bls12-381/bls12-381.go +++ b/ecc/bls12-381/bls12-381.go @@ -1,23 +1,29 @@ // Package bls12381 efficient elliptic curve, pairing and hash to curve implementation for bls12-381. // // bls12-381: A Barreto--Lynn--Scott curve -// embedding degree k=12 -// seed x₀=-15132376222941642752 -// 𝔽r: r=52435875175126190479447740508185965837690552500527637822603658699938581184513 (x₀⁴-x₀²+1) -// 𝔽p: p=4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 ((x₀-1)² ⋅ r(x₀)/3+x₀) -// (E/𝔽p): Y²=X³+4 -// (Eₜ/𝔽p²): Y² = X³+4(u+1) (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// +// embedding degree k=12 +// seed x₀=-15132376222941642752 +// 𝔽r: r=52435875175126190479447740508185965837690552500527637822603658699938581184513 (x₀⁴-x₀²+1) +// 𝔽p: p=4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 ((x₀-1)² ⋅ r(x₀)/3+x₀) +// (E/𝔽p): Y²=X³+4 +// (Eₜ/𝔽p²): Y² = X³+4(u+1) (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²+1 -// 𝔽p⁶[v] = 𝔽p²/v³-1-u -// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-1-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// // optimal Ate loop size: -// x₀ +// +// x₀ +// // Security: estimated 126-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 255 bits and p¹² is 4569 bits) // -// Warning +// # Warning // // This code has been partially audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bls12381 diff --git a/ecc/bls12-381/fp/doc.go b/ecc/bls12-381/fp/doc.go index 03d9919b4..bf5b82c11 100644 --- a/ecc/bls12-381/fp/doc.go +++ b/ecc/bls12-381/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [6]uint64 // -// Usage +// type Element [6]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 -// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab +// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 +// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 6c9aad2bc..e995024e0 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 6 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 -// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab +// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 +// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [6]uint64 const ( - Limbs = 6 // number of 64 bits words needed to represent a Element - Bits = 381 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 6 // number of 64 bits words needed to represent a Element + Bits = 381 // number of bits needed to represent a Element + Bytes = 48 // number of bytes needed to represent a Element ) // Field modulus q @@ -72,8 +72,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 -// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab +// q[base10] = 4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787 +// q[base16] = 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -82,12 +82,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 9940570264628428797 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", 16) } @@ -95,8 +89,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -107,7 +102,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -139,14 +134,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -260,15 +256,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -280,15 +274,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[5] > _x[5] { return 1 } else if _z[5] < _x[5] { @@ -329,8 +320,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 15924587544893707606, 0) @@ -429,67 +419,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -505,7 +437,7 @@ func (z *Element) Add(x, y *Element) *Element { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -529,7 +461,7 @@ func (z *Element) Double(x *Element) *Element { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -592,115 +524,219 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [6]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - z[5], z[4] = madd3(m, q5, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [7]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + + if t[6] != 0 { + // we need to reduce, we have a result on 7 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], _ = bits.Sub64(t[5], q5, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -710,7 +746,6 @@ func _mulGeneric(z, x, y *Element) { z[4], b = bits.Sub64(z[4], q4, b) z[5], _ = bits.Sub64(z[5], q5, b) } - } func _fromMontGeneric(z *Element) { @@ -784,7 +819,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -798,7 +833,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -870,6 +905,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -884,8 +948,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -913,23 +977,31 @@ var rSquare = Element{ 1267921511277847466, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -948,51 +1020,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[40:48], z[0]) - binary.BigEndian.PutUint64(b[32:40], z[1]) - binary.BigEndian.PutUint64(b[24:32], z[2]) - binary.BigEndian.PutUint64(b[16:24], z[3]) - binary.BigEndian.PutUint64(b[8:16], z[4]) - binary.BigEndian.PutUint64(b[0:8], z[5]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[40:48], _z[0]) - binary.BigEndian.PutUint64(res[32:40], _z[1]) - binary.BigEndian.PutUint64(res[24:32], _z[2]) - binary.BigEndian.PutUint64(res[16:24], _z[3]) - binary.BigEndian.PutUint64(res[8:16], _z[4]) - binary.BigEndian.PutUint64(res[0:8], _z[5]) +// Bits provides access to z by returning its value as a little-endian [6]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [6]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1005,19 +1075,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 48-byte integer. +// If e is not a 48-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1035,17 +1130,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1067,20 +1161,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1089,7 +1183,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1098,7 +1192,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1138,7 +1232,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1147,10 +1241,87 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 48-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[40:48]) + z[1] = binary.BigEndian.Uint64((*b)[32:40]) + z[2] = binary.BigEndian.Uint64((*b)[24:32]) + z[3] = binary.BigEndian.Uint64((*b)[16:24]) + z[4] = binary.BigEndian.Uint64((*b)[8:16]) + z[5] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[40:48], e[0]) + binary.BigEndian.PutUint64((*b)[32:40], e[1]) + binary.BigEndian.PutUint64((*b)[24:32], e[2]) + binary.BigEndian.PutUint64((*b)[16:24], e[3]) + binary.BigEndian.PutUint64((*b)[8:16], e[4]) + binary.BigEndian.PutUint64((*b)[0:8], e[5]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1342,7 +1513,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1359,17 +1530,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1502,7 +1684,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[5], z[4] = madd2(m, q5, t[i+5], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-381/fp/element_mul_adx_amd64.s b/ecc/bls12-381/fp/element_mul_adx_amd64.s deleted file mode 100644 index 325ae42f7..000000000 --- a/ecc/bls12-381/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,835 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET diff --git a/ecc/bls12-381/fp/element_mul_amd64.s b/ecc/bls12-381/fp/element_mul_amd64.s index de478c793..9e03b1c0a 100644 --- a/ecc/bls12-381/fp/element_mul_amd64.s +++ b/ecc/bls12-381/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fp/element_ops_amd64.go b/ecc/bls12-381/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.go +++ b/ecc/bls12-381/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-381/fp/element_ops_amd64.s b/ecc/bls12-381/fp/element_ops_amd64.s index 2aec95be4..830b2dd63 100644 --- a/ecc/bls12-381/fp/element_ops_amd64.s +++ b/ecc/bls12-381/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-381/fp/element_ops_noasm.go b/ecc/bls12-381/fp/element_ops_noasm.go deleted file mode 100644 index 50ba5e32e..000000000 --- a/ecc/bls12-381/fp/element_ops_noasm.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 13438459813099623723, - 14459933216667336738, - 14900020990258308116, - 2941282712809091851, - 13639094935183769893, - 1835248516986607988, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls12-381/fp/element_ops_purego.go b/ecc/bls12-381/fp/element_ops_purego.go new file mode 100644 index 000000000..fc10b3df3 --- /dev/null +++ b/ecc/bls12-381/fp/element_ops_purego.go @@ -0,0 +1,745 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 13438459813099623723, + 14459933216667336738, + 14900020990258308116, + 2941282712809091851, + 13639094935183769893, + 1835248516986607988, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index 22588929e..c2a5fc32a 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -642,7 +635,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -757,7 +750,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -770,13 +763,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -805,17 +798,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -866,7 +859,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -879,13 +872,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -914,17 +907,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -975,7 +968,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -988,7 +981,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1002,7 +995,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1042,11 +1035,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1059,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1111,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1124,14 +1117,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1160,18 +1153,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1222,7 +1215,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,13 +1228,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1270,17 +1263,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1325,7 +1318,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1346,14 +1339,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1397,7 +1390,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1418,14 +1411,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1469,7 +1462,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1490,14 +1483,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1541,7 +1534,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1562,14 +1555,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1613,7 +1606,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1634,14 +1627,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2029,7 +2022,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2162,17 +2155,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2255,7 +2248,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2323,7 +2316,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2367,7 +2360,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord4, inversionCorrectionFactorWord5, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2409,7 +2402,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2558,7 +2551,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2771,11 +2764,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2812,7 +2805,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2830,7 +2823,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2841,7 +2834,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-381/fr/doc.go b/ecc/bls12-381/fr/doc.go index 530ab29ed..71ea8a2dc 100644 --- a/ecc/bls12-381/fr/doc.go +++ b/ecc/bls12-381/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 -// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 +// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 +// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index f825c8768..7b753d7b5 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 -// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 +// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 +// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 255 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 255 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 -// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 +// q[base10] = 52435875175126190479447740508185965837690552500527637822603658699938581184513 +// q[base16] = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 18446744069414584319 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 9223372034707292161, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 524908885293268753, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(32) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls12-381/fr/element_mul_adx_amd64.s b/ecc/bls12-381/fr/element_mul_adx_amd64.s deleted file mode 100644 index 5068db70e..000000000 --- a/ecc/bls12-381/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xffffffff00000001 -DATA q<>+8(SB)/8, $0x53bda402fffe5bfe -DATA q<>+16(SB)/8, $0x3339d80809a1d805 -DATA q<>+24(SB)/8, $0x73eda753299d7d48 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xfffffffeffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bls12-381/fr/element_mul_amd64.s b/ecc/bls12-381/fr/element_mul_amd64.s index 467a39f8f..ef89cc5df 100644 --- a/ecc/bls12-381/fr/element_mul_amd64.s +++ b/ecc/bls12-381/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls12-381/fr/element_ops_amd64.go b/ecc/bls12-381/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.go +++ b/ecc/bls12-381/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls12-381/fr/element_ops_amd64.s b/ecc/bls12-381/fr/element_ops_amd64.s index daf46847a..dde381328 100644 --- a/ecc/bls12-381/fr/element_ops_amd64.s +++ b/ecc/bls12-381/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-381/fr/element_ops_purego.go b/ecc/bls12-381/fr/element_ops_purego.go new file mode 100644 index 000000000..258157ab7 --- /dev/null +++ b/ecc/bls12-381/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 120259084260, + 15510977298029211676, + 7326335280343703402, + 5909200893219589146, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index 84d026fce..2ad111aac 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls12-381/fr/fri/fri.go b/ecc/bls12-381/fr/fri/fri.go index 85fc69bc6..b526bcf38 100644 --- a/ecc/bls12-381/fr/fri/fri.go +++ b/ecc/bls12-381/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bls12-381/fr/gkr/gkr.go b/ecc/bls12-381/fr/gkr/gkr.go new file mode 100644 index 000000000..185b4455e --- /dev/null +++ b/ecc/bls12-381/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bls12-381/fr/gkr/gkr_test.go b/ecc/bls12-381/fr/gkr/gkr_test.go new file mode 100644 index 000000000..2dbc0e90a --- /dev/null +++ b/ecc/bls12-381/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bls12-381/fr/kzg/kzg.go b/ecc/bls12-381/fr/kzg/kzg.go index 0a7a712d3..e7e4d3d3b 100644 --- a/ecc/bls12-381/fr/kzg/kzg.go +++ b/ecc/bls12-381/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bls12381.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bls12381.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bls12-381/fr/mimc/decompose.go b/ecc/bls12-381/fr/mimc/decompose.go new file mode 100644 index 000000000..925d67932 --- /dev/null +++ b/ecc/bls12-381/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bls12-381/fr/mimc/decompose_test.go b/ecc/bls12-381/fr/mimc/decompose_test.go new file mode 100644 index 000000000..36809a2aa --- /dev/null +++ b/ecc/bls12-381/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bls12-381/fr/mimc/mimc.go b/ecc/bls12-381/fr/mimc/mimc.go index 89287dd06..e704e01f7 100644 --- a/ecc/bls12-381/fr/mimc/mimc.go +++ b/ecc/bls12-381/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bls12-381/fr/pedersen/pedersen.go b/ecc/bls12-381/fr/pedersen/pedersen.go new file mode 100644 index 000000000..38cc4d32c --- /dev/null +++ b/ecc/bls12-381/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bls12381.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bls12381.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bls12381.G1Affine + basisExpSigma []bls12381.G1Affine +} + +func randomOnG2() (bls12381.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12381.G2Affine{}, err + } + return bls12381.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bls12381.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bls12381.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bls12381.G1Affine, knowledgeProof bls12381.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bls12381.G1Affine, knowledgeProof bls12381.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bls12381.Pair([]bls12381.G1Affine{commitment, knowledgeProof}, []bls12381.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bls12-381/fr/pedersen/pedersen_test.go b/ecc/bls12-381/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..111f63590 --- /dev/null +++ b/ecc/bls12-381/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bls12381.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls12381.G1Affine{}, err + } + return bls12381.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bls12381.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bls12381.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bls12-381/fr/plookup/vector.go b/ecc/bls12-381/fr/plookup/vector.go index 3e9be2c81..07cc44655 100644 --- a/ecc/bls12-381/fr/plookup/vector.go +++ b/ecc/bls12-381/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bls12-381/fr/polynomial/multilin.go b/ecc/bls12-381/fr/polynomial/multilin.go index 087ef65e7..f668c3898 100644 --- a/ecc/bls12-381/fr/polynomial/multilin.go +++ b/ecc/bls12-381/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bls12-381/fr/polynomial/polynomial_test.go b/ecc/bls12-381/fr/polynomial/polynomial_test.go index 0376d294c..5df4aebae 100644 --- a/ecc/bls12-381/fr/polynomial/polynomial_test.go +++ b/ecc/bls12-381/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bls12-381/fr/polynomial/pool.go b/ecc/bls12-381/fr/polynomial/pool.go index a364c152b..81132603f 100644 --- a/ecc/bls12-381/fr/polynomial/pool.go +++ b/ecc/bls12-381/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bls12-381/fr/sumcheck/sumcheck.go b/ecc/bls12-381/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..a39dc48aa --- /dev/null +++ b/ecc/bls12-381/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bls12-381/fr/sumcheck/sumcheck_test.go b/ecc/bls12-381/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..b62b8a915 --- /dev/null +++ b/ecc/bls12-381/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go b/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..4bfd2a790 --- /dev/null +++ b/ecc/bls12-381/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" + "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bls12-381/g1.go b/ecc/bls12-381/g1.go index 26e21d85a..a6e2f174c 100644 --- a/ecc/bls12-381/g1.go +++ b/ecc/bls12-381/g1.go @@ -17,13 +17,12 @@ package bls12381 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -381,6 +387,7 @@ func (p *G1Jac) IsOnCurve() bool { func (p *G1Jac) IsInSubGroup() bool { var res G1Jac + res.phi(p). ScalarMultiplication(&res, &xGen). ScalarMultiplication(&res, &xGen). @@ -472,8 +479,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -583,15 +590,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -603,11 +610,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -632,6 +635,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -874,95 +879,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -976,3 +958,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-381/g1_test.go b/ecc/bls12-381/g1_test.go index 9aa3311f0..7e25d4d4e 100644 --- a/ecc/bls12-381/g1_test.go +++ b/ecc/bls12-381/g1_test.go @@ -19,6 +19,7 @@ package bls12381 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-381/g2.go b/ecc/bls12-381/g2.go index d2e3a4b91..d30c7b34b 100644 --- a/ecc/bls12-381/g2.go +++ b/ecc/bls12-381/g2.go @@ -17,13 +17,12 @@ package bls12381 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" "github.com/consensys/gnark-crypto/ecc/bls12-381/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E2 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E2 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -371,7 +377,8 @@ func (p *G2Jac) IsOnCurve() bool { // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // https://eprint.iacr.org/2021/1130.pdf, sec.4 -// ψ(p) = x₀ P +// and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 +// ψ(p) = [x₀]P func (p *G2Jac) IsInSubGroup() bool { var res, tmp G2Jac tmp.psi(p) @@ -472,8 +479,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -601,15 +608,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E2 + var A, B, U1, U2, S1, S2 fptower.E2 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -621,11 +628,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E2 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E2 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -650,6 +653,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E2 @@ -873,93 +878,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -972,3 +954,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E2 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E2 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls12-381/g2_test.go b/ecc/bls12-381/g2_test.go index c25960662..026c06e4c 100644 --- a/ecc/bls12-381/g2_test.go +++ b/ecc/bls12-381/g2_test.go @@ -19,6 +19,7 @@ package bls12381 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls12-381/internal/fptower" @@ -339,7 +340,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -464,8 +465,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -478,7 +478,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -505,6 +505,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -517,8 +544,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls12-381/hash_to_g1.go b/ecc/bls12-381/hash_to_g1.go index 23c108b94..3f57d725a 100644 --- a/ecc/bls12-381/hash_to_g1.go +++ b/ecc/bls12-381/hash_to_g1.go @@ -17,7 +17,6 @@ package bls12381 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" "math/big" @@ -265,35 +264,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -311,11 +289,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -331,9 +309,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bls12-381/hash_to_g1_test.go b/ecc/bls12-381/hash_to_g1_test.go index 8185fb0ef..bcadce17a 100644 --- a/ecc/bls12-381/hash_to_g1_test.go +++ b/ecc/bls12-381/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bls12-381/hash_to_g2.go b/ecc/bls12-381/hash_to_g2.go index 19f8fbfa8..ace9c207a 100644 --- a/ecc/bls12-381/hash_to_g2.go +++ b/ecc/bls12-381/hash_to_g2.go @@ -315,8 +315,7 @@ func g2EvalPolynomial(z *fptower.E2, monic bool, coefficients []fptower.E2, x *f // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fptower.E2) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() sign := uint64(0) // 1. sign = 0 zero := uint64(1) // 2. zero = 1 @@ -350,11 +349,11 @@ func MapToG2(u fptower.E2) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SSWU map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 2) + u, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -373,9 +372,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SSWU map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*2) + u, err := fp.Hash(msg, dst, 2*2) if err != nil { return G2Affine{}, err } diff --git a/ecc/bls12-381/hash_to_g2_test.go b/ecc/bls12-381/hash_to_g2_test.go index cc428eb89..b48b87ff2 100644 --- a/ecc/bls12-381/hash_to_g2_test.go +++ b/ecc/bls12-381/hash_to_g2_test.go @@ -64,7 +64,7 @@ func TestG2SqrtRatio(t *testing.T) { func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 2) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 2) if err != nil { t.Error(err) } @@ -72,7 +72,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*2) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*2) if err != nil { t.Error(err) } @@ -222,7 +222,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE2Prime(p G2Affine) bool { var A, B fptower.E2 @@ -251,7 +251,7 @@ func isOnE2Prime(p G2Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fptower.E2, s string) { ssplit := strings.Split(s, ",") if len(ssplit) != 2 { diff --git a/ecc/bls12-381/hash_vectors_test.go b/ecc/bls12-381/hash_vectors_test.go index 0678e7e59..49e9b2344 100644 --- a/ecc/bls12-381/hash_vectors_test.go +++ b/ecc/bls12-381/hash_vectors_test.go @@ -9,6 +9,7 @@ import ( func TestG1IsogenyVectors(t *testing.T) { t.Parallel() + // TODO @gbotrel fix me test vectors shouldn't set words directly p := G1Affine{ fp.Element{ 3660217524291093078, 10096673235325531916, 228883846699980880, 13273309082988818590, 5645112663858216297, 1475745906155504807, @@ -17,8 +18,8 @@ func TestG1IsogenyVectors(t *testing.T) { 7179819451626801451, 8122998708501415251, 10493900036512999567, 8666325578439571587, 1547096619901497872, 644447436619416978, }, } - p.X.ToMont() - p.Y.ToMont() + toMont(&p.X) + toMont(&p.Y) ref := G1Affine{ fp.Element{ @@ -29,8 +30,8 @@ func TestG1IsogenyVectors(t *testing.T) { }, } - ref.X.ToMont() - ref.Y.ToMont() + toMont(&ref.X) + toMont(&ref.Y) g1Isogeny(&p) @@ -220,3 +221,18 @@ func init() { }, } } + +var rSquare = fp.Element{ + 17644856173732828998, + 754043588434789617, + 10224657059481499349, + 7488229067341005760, + 11130996698012816685, + 1267921511277847466, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func toMont(z *fp.Element) { + z.Mul(z, &rSquare) +} diff --git a/ecc/bls12-381/internal/fptower/e12.go b/ecc/bls12-381/internal/fptower/e12.go index 50bb505e5..35a43eb6d 100644 --- a/ecc/bls12-381/internal/fptower/e12.go +++ b/ecc/bls12-381/internal/fptower/e12.go @@ -17,7 +17,6 @@ package fptower import ( - "encoding/binary" "errors" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" @@ -68,20 +67,6 @@ func (z *E12) SetOne() *E12 { return z } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - return z -} - // Add set z=x+y in E12 and return z func (z *E12) Add(x, y *E12) *E12 { z.C0.Add(&x.C0, &y.C0) @@ -119,6 +104,10 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() } +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + // Mul set z=x*y in E12 and return z func (z *E12) Mul(x, y *E12) *E12 { var a, b, c E6 @@ -226,9 +215,12 @@ func (z *E12) CyclotomicSquareCompressed(x *E12) *E12 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -289,9 +281,12 @@ func (z *E12) DecompressKarabina(x *E12) *E12 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -602,8 +597,8 @@ func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { @@ -652,93 +647,20 @@ func (z *E12) Unmarshal(buf []byte) error { // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) Bytes() (r [SizeOfGT]byte) { - _z := *z - _z.FromMont() - binary.BigEndian.PutUint64(r[568:576], _z.C0.B0.A0[0]) - binary.BigEndian.PutUint64(r[560:568], _z.C0.B0.A0[1]) - binary.BigEndian.PutUint64(r[552:560], _z.C0.B0.A0[2]) - binary.BigEndian.PutUint64(r[544:552], _z.C0.B0.A0[3]) - binary.BigEndian.PutUint64(r[536:544], _z.C0.B0.A0[4]) - binary.BigEndian.PutUint64(r[528:536], _z.C0.B0.A0[5]) - - binary.BigEndian.PutUint64(r[520:528], _z.C0.B0.A1[0]) - binary.BigEndian.PutUint64(r[512:520], _z.C0.B0.A1[1]) - binary.BigEndian.PutUint64(r[504:512], _z.C0.B0.A1[2]) - binary.BigEndian.PutUint64(r[496:504], _z.C0.B0.A1[3]) - binary.BigEndian.PutUint64(r[488:496], _z.C0.B0.A1[4]) - binary.BigEndian.PutUint64(r[480:488], _z.C0.B0.A1[5]) - - binary.BigEndian.PutUint64(r[472:480], _z.C0.B1.A0[0]) - binary.BigEndian.PutUint64(r[464:472], _z.C0.B1.A0[1]) - binary.BigEndian.PutUint64(r[456:464], _z.C0.B1.A0[2]) - binary.BigEndian.PutUint64(r[448:456], _z.C0.B1.A0[3]) - binary.BigEndian.PutUint64(r[440:448], _z.C0.B1.A0[4]) - binary.BigEndian.PutUint64(r[432:440], _z.C0.B1.A0[5]) - - binary.BigEndian.PutUint64(r[424:432], _z.C0.B1.A1[0]) - binary.BigEndian.PutUint64(r[416:424], _z.C0.B1.A1[1]) - binary.BigEndian.PutUint64(r[408:416], _z.C0.B1.A1[2]) - binary.BigEndian.PutUint64(r[400:408], _z.C0.B1.A1[3]) - binary.BigEndian.PutUint64(r[392:400], _z.C0.B1.A1[4]) - binary.BigEndian.PutUint64(r[384:392], _z.C0.B1.A1[5]) - - binary.BigEndian.PutUint64(r[376:384], _z.C0.B2.A0[0]) - binary.BigEndian.PutUint64(r[368:376], _z.C0.B2.A0[1]) - binary.BigEndian.PutUint64(r[360:368], _z.C0.B2.A0[2]) - binary.BigEndian.PutUint64(r[352:360], _z.C0.B2.A0[3]) - binary.BigEndian.PutUint64(r[344:352], _z.C0.B2.A0[4]) - binary.BigEndian.PutUint64(r[336:344], _z.C0.B2.A0[5]) - - binary.BigEndian.PutUint64(r[328:336], _z.C0.B2.A1[0]) - binary.BigEndian.PutUint64(r[320:328], _z.C0.B2.A1[1]) - binary.BigEndian.PutUint64(r[312:320], _z.C0.B2.A1[2]) - binary.BigEndian.PutUint64(r[304:312], _z.C0.B2.A1[3]) - binary.BigEndian.PutUint64(r[296:304], _z.C0.B2.A1[4]) - binary.BigEndian.PutUint64(r[288:296], _z.C0.B2.A1[5]) - - binary.BigEndian.PutUint64(r[280:288], _z.C1.B0.A0[0]) - binary.BigEndian.PutUint64(r[272:280], _z.C1.B0.A0[1]) - binary.BigEndian.PutUint64(r[264:272], _z.C1.B0.A0[2]) - binary.BigEndian.PutUint64(r[256:264], _z.C1.B0.A0[3]) - binary.BigEndian.PutUint64(r[248:256], _z.C1.B0.A0[4]) - binary.BigEndian.PutUint64(r[240:248], _z.C1.B0.A0[5]) - - binary.BigEndian.PutUint64(r[232:240], _z.C1.B0.A1[0]) - binary.BigEndian.PutUint64(r[224:232], _z.C1.B0.A1[1]) - binary.BigEndian.PutUint64(r[216:224], _z.C1.B0.A1[2]) - binary.BigEndian.PutUint64(r[208:216], _z.C1.B0.A1[3]) - binary.BigEndian.PutUint64(r[200:208], _z.C1.B0.A1[4]) - binary.BigEndian.PutUint64(r[192:200], _z.C1.B0.A1[5]) - - binary.BigEndian.PutUint64(r[184:192], _z.C1.B1.A0[0]) - binary.BigEndian.PutUint64(r[176:184], _z.C1.B1.A0[1]) - binary.BigEndian.PutUint64(r[168:176], _z.C1.B1.A0[2]) - binary.BigEndian.PutUint64(r[160:168], _z.C1.B1.A0[3]) - binary.BigEndian.PutUint64(r[152:160], _z.C1.B1.A0[4]) - binary.BigEndian.PutUint64(r[144:152], _z.C1.B1.A0[5]) - - binary.BigEndian.PutUint64(r[136:144], _z.C1.B1.A1[0]) - binary.BigEndian.PutUint64(r[128:136], _z.C1.B1.A1[1]) - binary.BigEndian.PutUint64(r[120:128], _z.C1.B1.A1[2]) - binary.BigEndian.PutUint64(r[112:120], _z.C1.B1.A1[3]) - binary.BigEndian.PutUint64(r[104:112], _z.C1.B1.A1[4]) - binary.BigEndian.PutUint64(r[96:104], _z.C1.B1.A1[5]) - - binary.BigEndian.PutUint64(r[88:96], _z.C1.B2.A0[0]) - binary.BigEndian.PutUint64(r[80:88], _z.C1.B2.A0[1]) - binary.BigEndian.PutUint64(r[72:80], _z.C1.B2.A0[2]) - binary.BigEndian.PutUint64(r[64:72], _z.C1.B2.A0[3]) - binary.BigEndian.PutUint64(r[56:64], _z.C1.B2.A0[4]) - binary.BigEndian.PutUint64(r[48:56], _z.C1.B2.A0[5]) - - binary.BigEndian.PutUint64(r[40:48], _z.C1.B2.A1[0]) - binary.BigEndian.PutUint64(r[32:40], _z.C1.B2.A1[1]) - binary.BigEndian.PutUint64(r[24:32], _z.C1.B2.A1[2]) - binary.BigEndian.PutUint64(r[16:24], _z.C1.B2.A1[3]) - binary.BigEndian.PutUint64(r[8:16], _z.C1.B2.A1[4]) - binary.BigEndian.PutUint64(r[0:8], _z.C1.B2.A1[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[528:528+fp.Bytes]), z.C0.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[480:480+fp.Bytes]), z.C0.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[432:432+fp.Bytes]), z.C0.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[384:384+fp.Bytes]), z.C0.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[336:336+fp.Bytes]), z.C0.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[288:288+fp.Bytes]), z.C0.B2.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[240:240+fp.Bytes]), z.C1.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[192:192+fp.Bytes]), z.C1.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[144:144+fp.Bytes]), z.C1.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[96:96+fp.Bytes]), z.C1.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[48:48+fp.Bytes]), z.C1.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[0:0+fp.Bytes]), z.C1.B2.A1) return } @@ -746,34 +668,47 @@ func (z *E12) Bytes() (r [SizeOfGT]byte) { // SetBytes interprets e as the bytes of a big-endian GT // sets z to that value (in Montgomery form), and returns z. // size(e) == 48 * 12 -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) SetBytes(e []byte) error { if len(e) != SizeOfGT { return errors.New("invalid buffer size") } - z.C0.B0.A0.SetBytes(e[528 : 528+fp.Bytes]) - - z.C0.B0.A1.SetBytes(e[480 : 480+fp.Bytes]) - - z.C0.B1.A0.SetBytes(e[432 : 432+fp.Bytes]) - - z.C0.B1.A1.SetBytes(e[384 : 384+fp.Bytes]) - - z.C0.B2.A0.SetBytes(e[336 : 336+fp.Bytes]) - - z.C0.B2.A1.SetBytes(e[288 : 288+fp.Bytes]) - - z.C1.B0.A0.SetBytes(e[240 : 240+fp.Bytes]) - - z.C1.B0.A1.SetBytes(e[192 : 192+fp.Bytes]) - - z.C1.B1.A0.SetBytes(e[144 : 144+fp.Bytes]) - - z.C1.B1.A1.SetBytes(e[96 : 96+fp.Bytes]) - - z.C1.B2.A0.SetBytes(e[48 : 48+fp.Bytes]) - - z.C1.B2.A1.SetBytes(e[0 : 0+fp.Bytes]) + if err := z.C0.B0.A0.SetBytesCanonical(e[528 : 528+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B0.A1.SetBytesCanonical(e[480 : 480+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A0.SetBytesCanonical(e[432 : 432+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A1.SetBytesCanonical(e[384 : 384+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A0.SetBytesCanonical(e[336 : 336+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A1.SetBytesCanonical(e[288 : 288+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A0.SetBytesCanonical(e[240 : 240+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A1.SetBytesCanonical(e[192 : 192+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A0.SetBytesCanonical(e[144 : 144+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A1.SetBytesCanonical(e[96 : 96+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A0.SetBytesCanonical(e[48 : 48+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A1.SetBytesCanonical(e[0 : 0+fp.Bytes]); err != nil { + return err + } return nil } diff --git a/ecc/bls12-381/internal/fptower/e2.go b/ecc/bls12-381/internal/fptower/e2.go index f15a70d85..45338f189 100644 --- a/ecc/bls12-381/internal/fptower/e2.go +++ b/ecc/bls12-381/internal/fptower/e2.go @@ -31,12 +31,20 @@ func (z *E2) Equal(x *E2) bool { return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) } +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -98,6 +106,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -127,20 +139,6 @@ func (z *E2) String() string { return z.A0.String() + "+" + z.A1.String() + "*u" } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bls12-381/internal/fptower/e2_adx_amd64.s b/ecc/bls12-381/internal/fptower/e2_adx_amd64.s deleted file mode 100644 index 3396a8901..000000000 --- a/ecc/bls12-381/internal/fptower/e2_adx_amd64.s +++ /dev/null @@ -1,3170 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xb9feffffffffaaab -DATA q<>+8(SB)/8, $0x1eabfffeb153ffff -DATA q<>+16(SB)/8, $0x6730d2a0f6b0f624 -DATA q<>+24(SB)/8, $0x64774b84f38512bf -DATA q<>+32(SB)/8, $0x4b1ba7b6434bacd7 -DATA q<>+40(SB)/8, $0x1a0111ea397fe69a -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x89f3fffcfffcfffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -TEXT ·addE2(SB), NOSPLIT, $0-24 - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ 32(AX), R9 - MOVQ 40(AX), R10 - MOVQ y+16(FP), DX - ADDQ 0(DX), BX - ADCQ 8(DX), SI - ADCQ 16(DX), DI - ADCQ 24(DX), R8 - ADCQ 32(DX), R9 - ADCQ 40(DX), R10 - - // reduce element(BX,SI,DI,R8,R9,R10) using temp registers (R11,R12,R13,R14,R15,s0-8(SP)) - REDUCE(BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP)) - - MOVQ res+0(FP), CX - MOVQ BX, 0(CX) - MOVQ SI, 8(CX) - MOVQ DI, 16(CX) - MOVQ R8, 24(CX) - MOVQ R9, 32(CX) - MOVQ R10, 40(CX) - MOVQ 48(AX), BX - MOVQ 56(AX), SI - MOVQ 64(AX), DI - MOVQ 72(AX), R8 - MOVQ 80(AX), R9 - MOVQ 88(AX), R10 - ADDQ 48(DX), BX - ADCQ 56(DX), SI - ADCQ 64(DX), DI - ADCQ 72(DX), R8 - ADCQ 80(DX), R9 - ADCQ 88(DX), R10 - - // reduce element(BX,SI,DI,R8,R9,R10) using temp registers (R11,R12,R13,R14,R15,s0-8(SP)) - REDUCE(BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15,s0-8(SP)) - - MOVQ BX, 48(CX) - MOVQ SI, 56(CX) - MOVQ DI, 64(CX) - MOVQ R8, 72(CX) - MOVQ R9, 80(CX) - MOVQ R10, 88(CX) - RET - -TEXT ·doubleE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - MOVQ 40(AX), R9 - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ CX, 0(DX) - MOVQ BX, 8(DX) - MOVQ SI, 16(DX) - MOVQ DI, 24(DX) - MOVQ R8, 32(DX) - MOVQ R9, 40(DX) - MOVQ 48(AX), CX - MOVQ 56(AX), BX - MOVQ 64(AX), SI - MOVQ 72(AX), DI - MOVQ 80(AX), R8 - MOVQ 88(AX), R9 - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - - // reduce element(CX,BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,R14,R15) - - MOVQ CX, 48(DX) - MOVQ BX, 56(DX) - MOVQ SI, 64(DX) - MOVQ DI, 72(DX) - MOVQ R8, 80(DX) - MOVQ R9, 88(DX) - RET - -TEXT ·subE2(SB), NOSPLIT, $0-24 - XORQ R9, R9 - MOVQ x+8(FP), R8 - MOVQ 0(R8), AX - MOVQ 8(R8), DX - MOVQ 16(R8), CX - MOVQ 24(R8), BX - MOVQ 32(R8), SI - MOVQ 40(R8), DI - MOVQ y+16(FP), R8 - SUBQ 0(R8), AX - SBBQ 8(R8), DX - SBBQ 16(R8), CX - SBBQ 24(R8), BX - SBBQ 32(R8), SI - SBBQ 40(R8), DI - MOVQ x+8(FP), R8 - MOVQ $0xb9feffffffffaaab, R10 - MOVQ $0x1eabfffeb153ffff, R11 - MOVQ $0x6730d2a0f6b0f624, R12 - MOVQ $0x64774b84f38512bf, R13 - MOVQ $0x4b1ba7b6434bacd7, R14 - MOVQ $0x1a0111ea397fe69a, R15 - CMOVQCC R9, R10 - CMOVQCC R9, R11 - CMOVQCC R9, R12 - CMOVQCC R9, R13 - CMOVQCC R9, R14 - CMOVQCC R9, R15 - ADDQ R10, AX - ADCQ R11, DX - ADCQ R12, CX - ADCQ R13, BX - ADCQ R14, SI - ADCQ R15, DI - MOVQ res+0(FP), R10 - MOVQ AX, 0(R10) - MOVQ DX, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - MOVQ SI, 32(R10) - MOVQ DI, 40(R10) - MOVQ 48(R8), AX - MOVQ 56(R8), DX - MOVQ 64(R8), CX - MOVQ 72(R8), BX - MOVQ 80(R8), SI - MOVQ 88(R8), DI - MOVQ y+16(FP), R8 - SUBQ 48(R8), AX - SBBQ 56(R8), DX - SBBQ 64(R8), CX - SBBQ 72(R8), BX - SBBQ 80(R8), SI - SBBQ 88(R8), DI - MOVQ $0xb9feffffffffaaab, R11 - MOVQ $0x1eabfffeb153ffff, R12 - MOVQ $0x6730d2a0f6b0f624, R13 - MOVQ $0x64774b84f38512bf, R14 - MOVQ $0x4b1ba7b6434bacd7, R15 - MOVQ $0x1a0111ea397fe69a, R10 - CMOVQCC R9, R11 - CMOVQCC R9, R12 - CMOVQCC R9, R13 - CMOVQCC R9, R14 - CMOVQCC R9, R15 - CMOVQCC R9, R10 - ADDQ R11, AX - ADCQ R12, DX - ADCQ R13, CX - ADCQ R14, BX - ADCQ R15, SI - ADCQ R10, DI - MOVQ res+0(FP), R8 - MOVQ AX, 48(R8) - MOVQ DX, 56(R8) - MOVQ CX, 64(R8) - MOVQ BX, 72(R8) - MOVQ SI, 80(R8) - MOVQ DI, 88(R8) - RET - -TEXT ·negE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ 32(AX), R9 - MOVQ 40(AX), R10 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - ORQ R9, AX - ORQ R10, AX - TESTQ AX, AX - JNE l1 - MOVQ AX, 0(DX) - MOVQ AX, 8(DX) - MOVQ AX, 16(DX) - MOVQ AX, 24(DX) - MOVQ AX, 32(DX) - MOVQ AX, 40(DX) - JMP l3 - -l1: - MOVQ $0xb9feffffffffaaab, CX - SUBQ BX, CX - MOVQ CX, 0(DX) - MOVQ $0x1eabfffeb153ffff, CX - SBBQ SI, CX - MOVQ CX, 8(DX) - MOVQ $0x6730d2a0f6b0f624, CX - SBBQ DI, CX - MOVQ CX, 16(DX) - MOVQ $0x64774b84f38512bf, CX - SBBQ R8, CX - MOVQ CX, 24(DX) - MOVQ $0x4b1ba7b6434bacd7, CX - SBBQ R9, CX - MOVQ CX, 32(DX) - MOVQ $0x1a0111ea397fe69a, CX - SBBQ R10, CX - MOVQ CX, 40(DX) - -l3: - MOVQ x+8(FP), AX - MOVQ 48(AX), BX - MOVQ 56(AX), SI - MOVQ 64(AX), DI - MOVQ 72(AX), R8 - MOVQ 80(AX), R9 - MOVQ 88(AX), R10 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - ORQ R9, AX - ORQ R10, AX - TESTQ AX, AX - JNE l2 - MOVQ AX, 48(DX) - MOVQ AX, 56(DX) - MOVQ AX, 64(DX) - MOVQ AX, 72(DX) - MOVQ AX, 80(DX) - MOVQ AX, 88(DX) - RET - -l2: - MOVQ $0xb9feffffffffaaab, CX - SUBQ BX, CX - MOVQ CX, 48(DX) - MOVQ $0x1eabfffeb153ffff, CX - SBBQ SI, CX - MOVQ CX, 56(DX) - MOVQ $0x6730d2a0f6b0f624, CX - SBBQ DI, CX - MOVQ CX, 64(DX) - MOVQ $0x64774b84f38512bf, CX - SBBQ R8, CX - MOVQ CX, 72(DX) - MOVQ $0x4b1ba7b6434bacd7, CX - SBBQ R9, CX - MOVQ CX, 80(DX) - MOVQ $0x1a0111ea397fe69a, CX - SBBQ R10, CX - MOVQ CX, 88(DX) - RET - -TEXT ·mulNonResE2(SB), NOSPLIT, $0-16 - XORQ R15, R15 - MOVQ x+8(FP), R14 - MOVQ 0(R14), AX - MOVQ 8(R14), DX - MOVQ 16(R14), CX - MOVQ 24(R14), BX - MOVQ 32(R14), SI - MOVQ 40(R14), DI - SUBQ 48(R14), AX - SBBQ 56(R14), DX - SBBQ 64(R14), CX - SBBQ 72(R14), BX - SBBQ 80(R14), SI - SBBQ 88(R14), DI - MOVQ $0xb9feffffffffaaab, R8 - MOVQ $0x1eabfffeb153ffff, R9 - MOVQ $0x6730d2a0f6b0f624, R10 - MOVQ $0x64774b84f38512bf, R11 - MOVQ $0x4b1ba7b6434bacd7, R12 - MOVQ $0x1a0111ea397fe69a, R13 - CMOVQCC R15, R8 - CMOVQCC R15, R9 - CMOVQCC R15, R10 - CMOVQCC R15, R11 - CMOVQCC R15, R12 - CMOVQCC R15, R13 - ADDQ R8, AX - ADCQ R9, DX - ADCQ R10, CX - ADCQ R11, BX - ADCQ R12, SI - ADCQ R13, DI - MOVQ 48(R14), R8 - MOVQ 56(R14), R9 - MOVQ 64(R14), R10 - MOVQ 72(R14), R11 - MOVQ 80(R14), R12 - MOVQ 88(R14), R13 - ADDQ 0(R14), R8 - ADCQ 8(R14), R9 - ADCQ 16(R14), R10 - ADCQ 24(R14), R11 - ADCQ 32(R14), R12 - ADCQ 40(R14), R13 - MOVQ res+0(FP), R15 - MOVQ AX, 0(R15) - MOVQ DX, 8(R15) - MOVQ CX, 16(R15) - MOVQ BX, 24(R15) - MOVQ SI, 32(R15) - MOVQ DI, 40(R15) - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (AX,DX,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,AX,DX,CX,BX,SI,DI) - - MOVQ R8, 48(R15) - MOVQ R9, 56(R15) - MOVQ R10, 64(R15) - MOVQ R11, 72(R15) - MOVQ R12, 80(R15) - MOVQ R13, 88(R15) - RET - -TEXT ·squareAdxE2(SB), $48-16 - NO_LOCAL_POINTERS - - // z.A0 = (x.A0 + x.A1) * (x.A0 - x.A1) - // z.A1 = 2 * x.A0 * x.A1 - - // 2 * x.A0 * x.A1 - MOVQ x+8(FP), AX - - // 2 * x.A1[0] -> R14 - // 2 * x.A1[1] -> R15 - // 2 * x.A1[2] -> CX - // 2 * x.A1[3] -> BX - // 2 * x.A1[4] -> SI - // 2 * x.A1[5] -> DI - MOVQ 48(AX), R14 - MOVQ 56(AX), R15 - MOVQ 64(AX), CX - MOVQ 72(AX), BX - MOVQ 80(AX), SI - MOVQ 88(AX), DI - ADDQ R14, R14 - ADCQ R15, R15 - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // A -> BP - // t[0] -> R8 - // t[1] -> R9 - // t[2] -> R10 - // t[3] -> R11 - // t[4] -> R12 - // t[5] -> R13 - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 0(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 8(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 16(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 24(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 32(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ x+8(FP), DX - MOVQ 40(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) - - MOVQ x+8(FP), AX - - // x.A1[0] -> R14 - // x.A1[1] -> R15 - // x.A1[2] -> CX - // x.A1[3] -> BX - // x.A1[4] -> SI - // x.A1[5] -> DI - MOVQ 48(AX), R14 - MOVQ 56(AX), R15 - MOVQ 64(AX), CX - MOVQ 72(AX), BX - MOVQ 80(AX), SI - MOVQ 88(AX), DI - MOVQ res+0(FP), DX - MOVQ R8, 48(DX) - MOVQ R9, 56(DX) - MOVQ R10, 64(DX) - MOVQ R11, 72(DX) - MOVQ R12, 80(DX) - MOVQ R13, 88(DX) - MOVQ R14, R8 - MOVQ R15, R9 - MOVQ CX, R10 - MOVQ BX, R11 - MOVQ SI, R12 - MOVQ DI, R13 - - // Add(&x.A0, &x.A1) - ADDQ 0(AX), R14 - ADCQ 8(AX), R15 - ADCQ 16(AX), CX - ADCQ 24(AX), BX - ADCQ 32(AX), SI - ADCQ 40(AX), DI - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ DI, s5-48(SP) - XORQ BP, BP - - // Sub(&x.A0, &x.A1) - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - MOVQ 40(AX), DI - SUBQ R8, R14 - SBBQ R9, R15 - SBBQ R10, CX - SBBQ R11, BX - SBBQ R12, SI - SBBQ R13, DI - MOVQ $0xb9feffffffffaaab, R8 - MOVQ $0x1eabfffeb153ffff, R9 - MOVQ $0x6730d2a0f6b0f624, R10 - MOVQ $0x64774b84f38512bf, R11 - MOVQ $0x4b1ba7b6434bacd7, R12 - MOVQ $0x1a0111ea397fe69a, R13 - CMOVQCC BP, R8 - CMOVQCC BP, R9 - CMOVQCC BP, R10 - CMOVQCC BP, R11 - CMOVQCC BP, R12 - CMOVQCC BP, R13 - ADDQ R8, R14 - ADCQ R9, R15 - ADCQ R10, CX - ADCQ R11, BX - ADCQ R12, SI - ADCQ R13, DI - - // A -> BP - // t[0] -> R8 - // t[1] -> R9 - // t[2] -> R10 - // t[3] -> R11 - // t[4] -> R12 - // t[5] -> R13 - // clear the flags - XORQ AX, AX - MOVQ s0-8(SP), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s1-16(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s2-24(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s3-32(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s4-40(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s5-48(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) - - MOVQ res+0(FP), AX - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) - MOVQ R12, 32(AX) - MOVQ R13, 40(AX) - RET - -TEXT ·mulAdxE2(SB), $96-24 - NO_LOCAL_POINTERS - - // var a, b, c fp.Element - // a.Add(&x.A0, &x.A1) - // b.Add(&y.A0, &y.A1) - // a.Mul(&a, &b) - // b.Mul(&x.A0, &y.A0) - // c.Mul(&x.A1, &y.A1) - // z.A1.Sub(&a, &b).Sub(&z.A1, &c) - // z.A0.Sub(&b, &c) - - MOVQ x+8(FP), AX - MOVQ 48(AX), R14 - MOVQ 56(AX), R15 - MOVQ 64(AX), CX - MOVQ 72(AX), BX - MOVQ 80(AX), SI - MOVQ 88(AX), DI - - // A -> BP - // t[0] -> R8 - // t[1] -> R9 - // t[2] -> R10 - // t[3] -> R11 - // t[4] -> R12 - // t[5] -> R13 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 48(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 56(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 64(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 72(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 80(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 88(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) - - MOVQ R8, s6-56(SP) - MOVQ R9, s7-64(SP) - MOVQ R10, s8-72(SP) - MOVQ R11, s9-80(SP) - MOVQ R12, s10-88(SP) - MOVQ R13, s11-96(SP) - MOVQ x+8(FP), AX - MOVQ y+16(FP), DX - MOVQ 48(AX), R14 - MOVQ 56(AX), R15 - MOVQ 64(AX), CX - MOVQ 72(AX), BX - MOVQ 80(AX), SI - MOVQ 88(AX), DI - ADDQ 0(AX), R14 - ADCQ 8(AX), R15 - ADCQ 16(AX), CX - ADCQ 24(AX), BX - ADCQ 32(AX), SI - ADCQ 40(AX), DI - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ DI, s5-48(SP) - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - ADDQ 48(DX), R14 - ADCQ 56(DX), R15 - ADCQ 64(DX), CX - ADCQ 72(DX), BX - ADCQ 80(DX), SI - ADCQ 88(DX), DI - - // A -> BP - // t[0] -> R8 - // t[1] -> R9 - // t[2] -> R10 - // t[3] -> R11 - // t[4] -> R12 - // t[5] -> R13 - // clear the flags - XORQ AX, AX - MOVQ s0-8(SP), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s1-16(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s2-24(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s3-32(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s4-40(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ s5-48(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) - - MOVQ R8, s0-8(SP) - MOVQ R9, s1-16(SP) - MOVQ R10, s2-24(SP) - MOVQ R11, s3-32(SP) - MOVQ R12, s4-40(SP) - MOVQ R13, s5-48(SP) - MOVQ x+8(FP), AX - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - MOVQ 40(AX), DI - - // A -> BP - // t[0] -> R8 - // t[1] -> R9 - // t[2] -> R10 - // t[3] -> R11 - // t[4] -> R12 - // t[5] -> R13 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 0(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, R8, R9 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R10 - ADOXQ AX, R9 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R11 - ADOXQ AX, R10 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R12 - ADOXQ AX, R11 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, R13 - ADOXQ AX, R12 - - // (A,t[5]) := x[5]*y[0] + A - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 8(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 16(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 24(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 32(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 40(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R14, AX, BP - ADOXQ AX, R8 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R9 - MULXQ R15, AX, BP - ADOXQ AX, R9 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, R10 - MULXQ CX, AX, BP - ADOXQ AX, R10 - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, R11 - MULXQ BX, AX, BP - ADOXQ AX, R11 - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, R12 - MULXQ SI, AX, BP - ADOXQ AX, R12 - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, R13 - MULXQ DI, AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R8, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R8, AX - MOVQ BP, R8 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R9, R8 - MULXQ q<>+8(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R10, R9 - MULXQ q<>+16(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R11, R10 - MULXQ q<>+24(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R12, R11 - MULXQ q<>+32(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ R13, R12 - MULXQ q<>+40(SB), AX, R13 - ADOXQ AX, R12 - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R8,R9,R10,R11,R12,R13) using temp registers (R14,R15,CX,BX,SI,DI) - REDUCE(R8,R9,R10,R11,R12,R13,R14,R15,CX,BX,SI,DI) - - XORQ DX, DX - MOVQ s0-8(SP), R14 - MOVQ s1-16(SP), R15 - MOVQ s2-24(SP), CX - MOVQ s3-32(SP), BX - MOVQ s4-40(SP), SI - MOVQ s5-48(SP), DI - SUBQ R8, R14 - SBBQ R9, R15 - SBBQ R10, CX - SBBQ R11, BX - SBBQ R12, SI - SBBQ R13, DI - MOVQ R8, s0-8(SP) - MOVQ R9, s1-16(SP) - MOVQ R10, s2-24(SP) - MOVQ R11, s3-32(SP) - MOVQ R12, s4-40(SP) - MOVQ R13, s5-48(SP) - MOVQ $0xb9feffffffffaaab, R8 - MOVQ $0x1eabfffeb153ffff, R9 - MOVQ $0x6730d2a0f6b0f624, R10 - MOVQ $0x64774b84f38512bf, R11 - MOVQ $0x4b1ba7b6434bacd7, R12 - MOVQ $0x1a0111ea397fe69a, R13 - CMOVQCC DX, R8 - CMOVQCC DX, R9 - CMOVQCC DX, R10 - CMOVQCC DX, R11 - CMOVQCC DX, R12 - CMOVQCC DX, R13 - ADDQ R8, R14 - ADCQ R9, R15 - ADCQ R10, CX - ADCQ R11, BX - ADCQ R12, SI - ADCQ R13, DI - SUBQ s6-56(SP), R14 - SBBQ s7-64(SP), R15 - SBBQ s8-72(SP), CX - SBBQ s9-80(SP), BX - SBBQ s10-88(SP), SI - SBBQ s11-96(SP), DI - MOVQ $0xb9feffffffffaaab, R8 - MOVQ $0x1eabfffeb153ffff, R9 - MOVQ $0x6730d2a0f6b0f624, R10 - MOVQ $0x64774b84f38512bf, R11 - MOVQ $0x4b1ba7b6434bacd7, R12 - MOVQ $0x1a0111ea397fe69a, R13 - CMOVQCC DX, R8 - CMOVQCC DX, R9 - CMOVQCC DX, R10 - CMOVQCC DX, R11 - CMOVQCC DX, R12 - CMOVQCC DX, R13 - ADDQ R8, R14 - ADCQ R9, R15 - ADCQ R10, CX - ADCQ R11, BX - ADCQ R12, SI - ADCQ R13, DI - MOVQ z+0(FP), AX - MOVQ R14, 48(AX) - MOVQ R15, 56(AX) - MOVQ CX, 64(AX) - MOVQ BX, 72(AX) - MOVQ SI, 80(AX) - MOVQ DI, 88(AX) - MOVQ s0-8(SP), R8 - MOVQ s1-16(SP), R9 - MOVQ s2-24(SP), R10 - MOVQ s3-32(SP), R11 - MOVQ s4-40(SP), R12 - MOVQ s5-48(SP), R13 - SUBQ s6-56(SP), R8 - SBBQ s7-64(SP), R9 - SBBQ s8-72(SP), R10 - SBBQ s9-80(SP), R11 - SBBQ s10-88(SP), R12 - SBBQ s11-96(SP), R13 - MOVQ $0xb9feffffffffaaab, R14 - MOVQ $0x1eabfffeb153ffff, R15 - MOVQ $0x6730d2a0f6b0f624, CX - MOVQ $0x64774b84f38512bf, BX - MOVQ $0x4b1ba7b6434bacd7, SI - MOVQ $0x1a0111ea397fe69a, DI - CMOVQCC DX, R14 - CMOVQCC DX, R15 - CMOVQCC DX, CX - CMOVQCC DX, BX - CMOVQCC DX, SI - CMOVQCC DX, DI - ADDQ R14, R8 - ADCQ R15, R9 - ADCQ CX, R10 - ADCQ BX, R11 - ADCQ SI, R12 - ADCQ DI, R13 - MOVQ R8, 0(AX) - MOVQ R9, 8(AX) - MOVQ R10, 16(AX) - MOVQ R11, 24(AX) - MOVQ R12, 32(AX) - MOVQ R13, 40(AX) - RET diff --git a/ecc/bls12-381/internal/fptower/e2_amd64.s b/ecc/bls12-381/internal/fptower/e2_amd64.s index 5717dc695..7fc53f463 100644 --- a/ecc/bls12-381/internal/fptower/e2_amd64.s +++ b/ecc/bls12-381/internal/fptower/e2_amd64.s @@ -1,5 +1,3 @@ -// +build !amd64_adx - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls12-381/internal/fptower/e6.go b/ecc/bls12-381/internal/fptower/e6.go index 4da093f5f..8ae7216ec 100644 --- a/ecc/bls12-381/internal/fptower/e6.go +++ b/ecc/bls12-381/internal/fptower/e6.go @@ -63,25 +63,13 @@ func (z *E6) SetRandom() (*E6, error) { return z, nil } -// IsZero returns true if the two elements are equal, fasle otherwise +// IsZero returns true if the two elements are equal, false otherwise func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - z.B2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - z.B2.FromMont() - return z +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() } // Add adds two elements of E6 diff --git a/ecc/bls12-381/marshal.go b/ecc/bls12-381/marshal.go index bcaf588bb..2b0e584f4 100644 --- a/ecc/bls12-381/marshal.go +++ b/ecc/bls12-381/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,14 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -680,29 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -753,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -774,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -848,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -857,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -866,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -909,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -921,23 +911,8 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) res[0] |= msbMask @@ -956,49 +931,16 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.A1 | p.Y.A0 - tmp = p.Y.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - - tmp = p.Y.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[136:144], tmp[0]) - binary.BigEndian.PutUint64(res[128:136], tmp[1]) - binary.BigEndian.PutUint64(res[120:128], tmp[2]) - binary.BigEndian.PutUint64(res[112:120], tmp[3]) - binary.BigEndian.PutUint64(res[104:112], tmp[4]) - binary.BigEndian.PutUint64(res[96:104], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[144:144+fp.Bytes]), p.Y.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y.A1) // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[40:48], tmp[0]) - binary.BigEndian.PutUint64(res[32:40], tmp[1]) - binary.BigEndian.PutUint64(res[24:32], tmp[2]) - binary.BigEndian.PutUint64(res[16:24], tmp[3]) - binary.BigEndian.PutUint64(res[8:16], tmp[4]) - binary.BigEndian.PutUint64(res[0:8], tmp[5]) - - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[48:48+fp.Bytes]), p.X.A0) res[0] |= mUncompressed @@ -1050,11 +992,19 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(buf[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // p.Y.A1 | p.Y.A0 - p.Y.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.Y.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.Y.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.Y.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1075,8 +1025,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } var YSquared, Y fptower.E2 @@ -1152,7 +1106,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1161,7 +1115,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1171,12 +1125,16 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return false, err + } // store mData in p.Y.A0[0] p.Y.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bls12-381/multiexp.go b/ecc/bls12-381/multiexp.go index 25a18a945..68e8b38ea 100644 --- a/ecc/bls12-381/multiexp.go +++ b/ecc/bls12-381/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,177 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +282,445 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bls12-381/multiexp_affine.go b/ecc/bls12-381/multiexp_affine.go new file mode 100644 index 000000000..f2fcc0573 --- /dev/null +++ b/ecc/bls12-381/multiexp_affine.go @@ -0,0 +1,686 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12381 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-381/fp" + "github.com/consensys/gnark-crypto/ecc/bls12-381/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E2 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E2 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E2 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E2 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E2 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E2 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E2 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E2 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bls12-381/multiexp_jacobian.go b/ecc/bls12-381/multiexp_jacobian.go new file mode 100644 index 000000000..2a2f8caa8 --- /dev/null +++ b/ecc/bls12-381/multiexp_jacobian.go @@ -0,0 +1,171 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls12381 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bls12-381/multiexp_test.go b/ecc/bls12-381/multiexp_test.go index 2c051ce12..f5a52ce42 100644 --- a/ecc/bls12-381/multiexp_test.go +++ b/ecc/bls12-381/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls12-381/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bls12-381/twistededwards/eddsa/doc.go b/ecc/bls12-381/twistededwards/eddsa/doc.go index fb4712633..af3fe1f93 100644 --- a/ecc/bls12-381/twistededwards/eddsa/doc.go +++ b/ecc/bls12-381/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls12-381's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go b/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go index 967aac1db..85bd27ca0 100644 --- a/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls12-381/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls12-381/twistededwards/eddsa/marshal.go b/ecc/bls12-381/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls12-381/twistededwards/eddsa/marshal.go +++ b/ecc/bls12-381/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls12-381/twistededwards/point.go b/ecc/bls12-381/twistededwards/point.go index aa28fa719..ec42a001a 100644 --- a/ecc/bls12-381/twistededwards/point.go +++ b/ecc/bls12-381/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bls24-315/bls24-315.go b/ecc/bls24-315/bls24-315.go index c5270c672..ade22cdb5 100644 --- a/ecc/bls24-315/bls24-315.go +++ b/ecc/bls24-315/bls24-315.go @@ -1,24 +1,30 @@ // Package bls24315 efficient elliptic curve, pairing and hash to curve implementation for bls24-315. // // bls24-315: A Barreto--Lynn--Scott curve -// embedding degree k=24 -// seed x₀=-3218079743 -// 𝔽r: r=0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 (x₀^8-x₀^4+2) -// 𝔽p: p=0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 ((x₀-1)² ⋅ r(x₀)/3+x₀) -// (E/𝔽p): Y²=X³+1 -// (Eₜ/𝔽p⁴): Y² = X³+1/v (D-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p⁴) +// +// embedding degree k=24 +// seed x₀=-3218079743 +// 𝔽r: r=0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 (x₀^8-x₀^4+2) +// 𝔽p: p=0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 ((x₀-1)² ⋅ r(x₀)/3+x₀) +// (E/𝔽p): Y²=X³+1 +// (Eₜ/𝔽p⁴): Y² = X³+1/v (D-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p⁴) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²-13 -// 𝔽p⁴[v] = 𝔽p²/v²-u -// 𝔽p¹²[w] = 𝔽p⁴/w³-v -// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// +// 𝔽p²[u] = 𝔽p/u²-13 +// 𝔽p⁴[v] = 𝔽p²/v²-u +// 𝔽p¹²[w] = 𝔽p⁴/w³-v +// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// // optimal Ate loop size: -// x₀ +// +// x₀ +// // Security: estimated 160-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 253 bits and p²⁴ is 7543 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bls24315 diff --git a/ecc/bls24-315/fp/doc.go b/ecc/bls24-315/fp/doc.go index 172938489..10acfff5d 100644 --- a/ecc/bls24-315/fp/doc.go +++ b/ecc/bls24-315/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [5]uint64 // -// Usage +// type Element [5]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index a48ffb0b5..2acd64a01 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 5 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [5]uint64 const ( - Limbs = 5 // number of 64 bits words needed to represent a Element - Bits = 315 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 5 // number of 64 bits words needed to represent a Element + Bits = 315 // number of bits needed to represent a Element + Bytes = 40 // number of bytes needed to represent a Element ) // Field modulus q @@ -70,8 +70,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -80,12 +80,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 8083954730842193919 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001", 16) } @@ -93,8 +87,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -105,7 +100,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -136,14 +131,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -255,15 +251,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -275,15 +269,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[4] > _x[4] { return 1 } else if _z[4] < _x[4] { @@ -319,8 +310,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 4031849214061838337, 0) @@ -415,67 +405,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -490,7 +422,7 @@ func (z *Element) Add(x, y *Element) *Element { z[3], carry = bits.Add64(x[3], y[3], carry) z[4], _ = bits.Add64(x[4], y[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -512,7 +444,7 @@ func (z *Element) Double(x *Element) *Element { z[3], carry = bits.Add64(x[3], x[3], carry) z[4], _ = bits.Add64(x[4], x[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -570,88 +502,181 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [5]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - z[4], z[3] = madd3(m, q4, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [6]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + + if t[5] != 0 { + // we need to reduce, we have a result on 6 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], _ = bits.Sub64(t[4], q4, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -660,7 +685,6 @@ func _mulGeneric(z, x, y *Element) { z[3], b = bits.Sub64(z[3], q3, b) z[4], _ = bits.Sub64(z[4], q4, b) } - } func _fromMontGeneric(z *Element) { @@ -718,7 +742,7 @@ func _fromMontGeneric(z *Element) { z[4] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -731,7 +755,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -799,6 +823,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -813,8 +866,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -841,23 +894,30 @@ var rSquare = Element{ 150264569250089173, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[32:40], z[0]) + binary.BigEndian.PutUint64(b[24:32], z[1]) + binary.BigEndian.PutUint64(b[16:24], z[2]) + binary.BigEndian.PutUint64(b[8:16], z[3]) + binary.BigEndian.PutUint64(b[0:8], z[4]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -876,49 +936,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[32:40], z[0]) - binary.BigEndian.PutUint64(b[24:32], z[1]) - binary.BigEndian.PutUint64(b[16:24], z[2]) - binary.BigEndian.PutUint64(b[8:16], z[3]) - binary.BigEndian.PutUint64(b[0:8], z[4]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[32:40], _z[0]) - binary.BigEndian.PutUint64(res[24:32], _z[1]) - binary.BigEndian.PutUint64(res[16:24], _z[2]) - binary.BigEndian.PutUint64(res[8:16], _z[3]) - binary.BigEndian.PutUint64(res[0:8], _z[4]) +// Bits provides access to z by returning its value as a little-endian [5]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [5]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -931,19 +991,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 40-byte integer. +// If e is not a 40-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -961,17 +1046,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -993,20 +1077,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1015,7 +1099,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1024,7 +1108,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1064,7 +1148,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1073,10 +1157,83 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 40-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[32:40]) + z[1] = binary.BigEndian.Uint64((*b)[24:32]) + z[2] = binary.BigEndian.Uint64((*b)[16:24]) + z[3] = binary.BigEndian.Uint64((*b)[8:16]) + z[4] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[32:40], e[0]) + binary.BigEndian.PutUint64((*b)[24:32], e[1]) + binary.BigEndian.PutUint64((*b)[16:24], e[2]) + binary.BigEndian.PutUint64((*b)[8:16], e[3]) + binary.BigEndian.PutUint64((*b)[0:8], e[4]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1109,7 +1266,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1123,7 +1280,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(20) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1317,7 +1474,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1333,17 +1490,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1458,7 +1626,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[4], z[3] = madd2(m, q4, t[i+4], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls24-315/fp/element_mul_adx_amd64.s b/ecc/bls24-315/fp/element_mul_adx_amd64.s deleted file mode 100644 index c02648d3a..000000000 --- a/ecc/bls24-315/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,634 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET diff --git a/ecc/bls24-315/fp/element_mul_amd64.s b/ecc/bls24-315/fp/element_mul_amd64.s index 94089b607..51165684d 100644 --- a/ecc/bls24-315/fp/element_mul_amd64.s +++ b/ecc/bls24-315/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fp/element_ops_amd64.go b/ecc/bls24-315/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.go +++ b/ecc/bls24-315/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls24-315/fp/element_ops_amd64.s b/ecc/bls24-315/fp/element_ops_amd64.s index c70e0a5ce..9528ab595 100644 --- a/ecc/bls24-315/fp/element_ops_amd64.s +++ b/ecc/bls24-315/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls24-315/fp/element_ops_noasm.go b/ecc/bls24-315/fp/element_ops_noasm.go deleted file mode 100644 index 906080833..000000000 --- a/ecc/bls24-315/fp/element_ops_noasm.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 8178485296672800069, - 8476448362227282520, - 14180928431697993131, - 4308307642551989706, - 120359802761433421, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls24-315/fp/element_ops_purego.go b/ecc/bls24-315/fp/element_ops_purego.go new file mode 100644 index 000000000..9a557a358 --- /dev/null +++ b/ecc/bls24-315/fp/element_ops_purego.go @@ -0,0 +1,582 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 8178485296672800069, + 8476448362227282520, + 14180928431697993131, + 4308307642551989706, + 120359802761433421, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index 5a9a9e08f..3276dfec3 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -640,7 +633,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -755,7 +748,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -768,13 +761,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -803,17 +796,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -864,7 +857,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -877,13 +870,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -912,17 +905,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -973,7 +966,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -986,7 +979,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1000,7 +993,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1040,11 +1033,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1057,7 +1050,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1109,7 +1102,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1122,14 +1115,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1158,18 +1151,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1220,7 +1213,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1233,13 +1226,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1268,17 +1261,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1323,7 +1316,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1344,14 +1337,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1395,7 +1388,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1416,14 +1409,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1467,7 +1460,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1488,14 +1481,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1539,7 +1532,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1560,14 +1553,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1611,7 +1604,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1632,14 +1625,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2027,7 +2020,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2160,17 +2153,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2251,7 +2244,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2316,7 +2309,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2359,7 +2352,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord3, inversionCorrectionFactorWord4, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2401,7 +2394,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2548,7 +2541,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2761,11 +2754,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2802,7 +2795,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2820,7 +2813,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2831,7 +2824,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls24-315/fr/doc.go b/ecc/bls24-315/fr/doc.go index 603ee608a..087f41798 100644 --- a/ecc/bls24-315/fr/doc.go +++ b/ecc/bls24-315/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 -// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 +// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 +// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index 12f3caea4..7f6b4aa34 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 -// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 +// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 +// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 253 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 253 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 -// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 +// q[base10] = 11502027791375260645628074404575422495959608200132055716665986169834464870401 +// q[base16] = 0x196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 2184305180030271487 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 930102168266997761, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 584663452775307866, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(22) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls24-315/fr/element_mul_adx_amd64.s b/ecc/bls24-315/fr/element_mul_adx_amd64.s deleted file mode 100644 index 9333858e9..000000000 --- a/ecc/bls24-315/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x19d0c5fd00c00001 -DATA q<>+8(SB)/8, $0xc8c480ece644e364 -DATA q<>+16(SB)/8, $0x25fc7ec9cf927a98 -DATA q<>+24(SB)/8, $0x196deac24a9da12b -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x1e5035fd00bfffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bls24-315/fr/element_mul_amd64.s b/ecc/bls24-315/fr/element_mul_amd64.s index fe961f243..e32f78354 100644 --- a/ecc/bls24-315/fr/element_mul_amd64.s +++ b/ecc/bls24-315/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-315/fr/element_ops_amd64.go b/ecc/bls24-315/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.go +++ b/ecc/bls24-315/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls24-315/fr/element_ops_amd64.s b/ecc/bls24-315/fr/element_ops_amd64.s index c7e207d85..e09e5ee14 100644 --- a/ecc/bls24-315/fr/element_ops_amd64.s +++ b/ecc/bls24-315/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls24-315/fr/element_ops_noasm.go b/ecc/bls24-315/fr/element_ops_noasm.go deleted file mode 100644 index 6fdd1b07e..000000000 --- a/ecc/bls24-315/fr/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 16427853282514304894, - 880039980351915818, - 13098611234035318378, - 1598436289436461078, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls24-315/fr/element_ops_purego.go b/ecc/bls24-315/fr/element_ops_purego.go new file mode 100644 index 000000000..8dcb67f76 --- /dev/null +++ b/ecc/bls24-315/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 16427853282514304894, + 880039980351915818, + 13098611234035318378, + 1598436289436461078, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index 349dbb518..7aa0c5ffc 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls24-315/fr/fri/fri.go b/ecc/bls24-315/fr/fri/fri.go index 20eee4fcb..7db8229f1 100644 --- a/ecc/bls24-315/fr/fri/fri.go +++ b/ecc/bls24-315/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bls24-315/fr/gkr/gkr.go b/ecc/bls24-315/fr/gkr/gkr.go new file mode 100644 index 000000000..d8570d13a --- /dev/null +++ b/ecc/bls24-315/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bls24-315/fr/gkr/gkr_test.go b/ecc/bls24-315/fr/gkr/gkr_test.go new file mode 100644 index 000000000..bc0801503 --- /dev/null +++ b/ecc/bls24-315/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bls24-315/fr/kzg/kzg.go b/ecc/bls24-315/fr/kzg/kzg.go index 505083c87..a844089a5 100644 --- a/ecc/bls24-315/fr/kzg/kzg.go +++ b/ecc/bls24-315/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bls24315.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bls24315.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bls24-315/fr/mimc/decompose.go b/ecc/bls24-315/fr/mimc/decompose.go new file mode 100644 index 000000000..4a962631f --- /dev/null +++ b/ecc/bls24-315/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bls24-315/fr/mimc/decompose_test.go b/ecc/bls24-315/fr/mimc/decompose_test.go new file mode 100644 index 000000000..817a588f2 --- /dev/null +++ b/ecc/bls24-315/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bls24-315/fr/mimc/mimc.go b/ecc/bls24-315/fr/mimc/mimc.go index 63e8f5e1f..f9971a900 100644 --- a/ecc/bls24-315/fr/mimc/mimc.go +++ b/ecc/bls24-315/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bls24-315/fr/pedersen/pedersen.go b/ecc/bls24-315/fr/pedersen/pedersen.go new file mode 100644 index 000000000..09e19bea1 --- /dev/null +++ b/ecc/bls24-315/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bls24315.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bls24315.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bls24315.G1Affine + basisExpSigma []bls24315.G1Affine +} + +func randomOnG2() (bls24315.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls24315.G2Affine{}, err + } + return bls24315.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bls24315.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bls24315.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bls24315.G1Affine, knowledgeProof bls24315.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bls24315.G1Affine, knowledgeProof bls24315.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bls24315.Pair([]bls24315.G1Affine{commitment, knowledgeProof}, []bls24315.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bls24-315/fr/pedersen/pedersen_test.go b/ecc/bls24-315/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..841896894 --- /dev/null +++ b/ecc/bls24-315/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-315" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bls24315.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls24315.G1Affine{}, err + } + return bls24315.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bls24315.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bls24315.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bls24-315/fr/plookup/vector.go b/ecc/bls24-315/fr/plookup/vector.go index 43752ab7f..ca0e5c78d 100644 --- a/ecc/bls24-315/fr/plookup/vector.go +++ b/ecc/bls24-315/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bls24-315/fr/polynomial/multilin.go b/ecc/bls24-315/fr/polynomial/multilin.go index 17bbf1ed2..708104eaa 100644 --- a/ecc/bls24-315/fr/polynomial/multilin.go +++ b/ecc/bls24-315/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bls24-315/fr/polynomial/polynomial_test.go b/ecc/bls24-315/fr/polynomial/polynomial_test.go index c23d86197..ad1f480f0 100644 --- a/ecc/bls24-315/fr/polynomial/polynomial_test.go +++ b/ecc/bls24-315/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bls24-315/fr/polynomial/pool.go b/ecc/bls24-315/fr/polynomial/pool.go index 045ba8eb1..d8d4e570e 100644 --- a/ecc/bls24-315/fr/polynomial/pool.go +++ b/ecc/bls24-315/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bls24-315/fr/sumcheck/sumcheck.go b/ecc/bls24-315/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..54b1dd785 --- /dev/null +++ b/ecc/bls24-315/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bls24-315/fr/sumcheck/sumcheck_test.go b/ecc/bls24-315/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..cfcaee5e9 --- /dev/null +++ b/ecc/bls24-315/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go b/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..17f7ca0ce --- /dev/null +++ b/ecc/bls24-315/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bls24-315/g1.go b/ecc/bls24-315/g1.go index c97b14149..d9c125045 100644 --- a/ecc/bls24-315/g1.go +++ b/ecc/bls24-315/g1.go @@ -17,13 +17,12 @@ package bls24315 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fp" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -373,6 +379,7 @@ func (p *G1Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG1Affine, 1] is the kernel // of (u,v)->u+lambdaG1Affinev mod r. Expressing r, lambdaG1Affine as // polynomials in x, a short vector of this Zmodule is @@ -474,8 +481,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -585,15 +592,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -605,11 +612,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -634,6 +637,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -876,95 +881,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -978,3 +960,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls24-315/g1_test.go b/ecc/bls24-315/g1_test.go index 5eba73ee9..6f553b866 100644 --- a/ecc/bls24-315/g1_test.go +++ b/ecc/bls24-315/g1_test.go @@ -19,6 +19,7 @@ package bls24315 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls24-315/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls24-315/g2.go b/ecc/bls24-315/g2.go index 81fb7f642..266eaab42 100644 --- a/ecc/bls24-315/g2.go +++ b/ecc/bls24-315/g2.go @@ -17,13 +17,12 @@ package bls24315 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" "github.com/consensys/gnark-crypto/ecc/bls24-315/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E4 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E4 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -371,7 +377,8 @@ func (p *G2Jac) IsOnCurve() bool { // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // https://eprint.iacr.org/2021/1130.pdf, sec.4 -// ψ(p) = x₀ P +// and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 +// ψ(p) = [x₀]P func (p *G2Jac) IsInSubGroup() bool { var res, tmp G2Jac tmp.psi(p) @@ -473,8 +480,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -616,15 +623,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E4 + var A, B, U1, U2, S1, S2 fptower.E4 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -636,11 +643,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E4 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E4 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -665,6 +668,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E4 @@ -888,93 +893,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -987,3 +969,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E4 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E4 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls24-315/g2_test.go b/ecc/bls24-315/g2_test.go index bab8fbad1..f89f79150 100644 --- a/ecc/bls24-315/g2_test.go +++ b/ecc/bls24-315/g2_test.go @@ -19,6 +19,7 @@ package bls24315 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls24-315/internal/fptower" @@ -339,7 +340,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -464,8 +465,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -478,7 +478,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -505,6 +505,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -517,8 +544,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls24-315/hash_to_g1.go b/ecc/bls24-315/hash_to_g1.go index faae02c7a..ecaa0185c 100644 --- a/ecc/bls24-315/hash_to_g1.go +++ b/ecc/bls24-315/hash_to_g1.go @@ -17,7 +17,6 @@ package bls24315 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fp" "math/big" @@ -256,35 +255,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -302,11 +280,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -322,9 +300,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bls24-315/hash_to_g1_test.go b/ecc/bls24-315/hash_to_g1_test.go index 9680c97e1..d536b1a23 100644 --- a/ecc/bls24-315/hash_to_g1_test.go +++ b/ecc/bls24-315/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bls24-315/hash_to_g2.go b/ecc/bls24-315/hash_to_g2.go index 3686bba0e..2a4b37a32 100644 --- a/ecc/bls24-315/hash_to_g2.go +++ b/ecc/bls24-315/hash_to_g2.go @@ -107,7 +107,7 @@ func MapToG2(t fptower.E4) G2Affine { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-2.2.2 func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - _t, err := hashToFp(msg, dst, 2) + _t, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -122,7 +122,7 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-3 func HashToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 4) + u, err := fp.Hash(msg, dst, 4) if err != nil { return res, err } diff --git a/ecc/bls24-315/internal/fptower/e12.go b/ecc/bls24-315/internal/fptower/e12.go index 9ddc554ea..0bc4d0b28 100644 --- a/ecc/bls24-315/internal/fptower/e12.go +++ b/ecc/bls24-315/internal/fptower/e12.go @@ -75,20 +75,8 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() && z.C2.IsZero() } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - z.C2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - z.C2.FromMont() - return z +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() && z.C2.IsZero() } // Add adds two elements of E12 diff --git a/ecc/bls24-315/internal/fptower/e2.go b/ecc/bls24-315/internal/fptower/e2.go index de6253587..73cf36264 100644 --- a/ecc/bls24-315/internal/fptower/e2.go +++ b/ecc/bls24-315/internal/fptower/e2.go @@ -32,10 +32,9 @@ func (z *E2) Equal(x *E2) bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -97,6 +96,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -126,20 +129,6 @@ func (z *E2) String() string { return (z.A0.String() + "+" + z.A1.String() + "*u") } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bls24-315/internal/fptower/e24.go b/ecc/bls24-315/internal/fptower/e24.go index 68a8dc3a7..888123993 100644 --- a/ecc/bls24-315/internal/fptower/e24.go +++ b/ecc/bls24-315/internal/fptower/e24.go @@ -66,20 +66,6 @@ func (z *E24) SetOne() *E24 { return z } -// ToMont converts to Mont form -func (z *E24) ToMont() *E24 { - z.D0.ToMont() - z.D1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E24) FromMont() *E24 { - z.D0.FromMont() - z.D1.FromMont() - return z -} - // Add set z=x+y in E24 and return z func (z *E24) Add(x, y *E24) *E24 { z.D0.Add(&x.D0, &y.D0) @@ -117,6 +103,10 @@ func (z *E24) IsZero() bool { return z.D0.IsZero() && z.D1.IsZero() } +func (z *E24) IsOne() bool { + return z.D0.IsOne() && z.D1.IsZero() +} + // Mul set z=x*y in E24 and return z func (z *E24) Mul(x, y *E24) *E24 { var a, b, c E12 @@ -224,9 +214,12 @@ func (z *E24) CyclotomicSquareCompressed(x *E24) *E24 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -251,7 +244,7 @@ func (z *E24) DecompressKarabina(x *E24) *E24 { t[1].Sub(&t[0], &x.D0.C2). Double(&t[1]). Add(&t[1], &t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t[2].Square(&x.D1.C2) t[0].MulByNonResidue(&t[2]). Add(&t[0], &t[1]) @@ -287,9 +280,12 @@ func (z *E24) DecompressKarabina(x *E24) *E24 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -325,7 +321,7 @@ func BatchDecompressKarabina(x []E24) []E24 { t1[i].Sub(&t0[i], &x[i].D0.C2). Double(&t1[i]). Add(&t1[i], &t0[i]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t2[i].Square(&x[i].D1.C2) t0[i].MulByNonResidue(&t2[i]). Add(&t0[i], &t1[i]) @@ -600,8 +596,8 @@ func (z *E24) ExpGLV(x E24, k *big.Int) *E24 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1)/2 + 1; i >= 0; i-- { diff --git a/ecc/bls24-315/internal/fptower/e2_adx_amd64.s b/ecc/bls24-315/internal/fptower/e2_adx_amd64.s deleted file mode 100644 index abb8f3c04..000000000 --- a/ecc/bls24-315/internal/fptower/e2_adx_amd64.s +++ /dev/null @@ -1,1634 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -TEXT ·addE2(SB), NOSPLIT, $0-24 - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ 32(AX), R9 - MOVQ y+16(FP), DX - ADDQ 0(DX), BX - ADCQ 8(DX), SI - ADCQ 16(DX), DI - ADCQ 24(DX), R8 - ADCQ 32(DX), R9 - - // reduce element(BX,SI,DI,R8,R9) using temp registers (R10,R11,R12,R13,R14) - REDUCE(BX,SI,DI,R8,R9,R10,R11,R12,R13,R14) - - MOVQ res+0(FP), CX - MOVQ BX, 0(CX) - MOVQ SI, 8(CX) - MOVQ DI, 16(CX) - MOVQ R8, 24(CX) - MOVQ R9, 32(CX) - MOVQ 40(AX), BX - MOVQ 48(AX), SI - MOVQ 56(AX), DI - MOVQ 64(AX), R8 - MOVQ 72(AX), R9 - ADDQ 40(DX), BX - ADCQ 48(DX), SI - ADCQ 56(DX), DI - ADCQ 64(DX), R8 - ADCQ 72(DX), R9 - - // reduce element(BX,SI,DI,R8,R9) using temp registers (R15,R10,R11,R12,R13) - REDUCE(BX,SI,DI,R8,R9,R15,R10,R11,R12,R13) - - MOVQ BX, 40(CX) - MOVQ SI, 48(CX) - MOVQ DI, 56(CX) - MOVQ R8, 64(CX) - MOVQ R9, 72(CX) - RET - -TEXT ·doubleE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - MOVQ 32(AX), R8 - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R9,R10,R11,R12,R13) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ CX, 0(DX) - MOVQ BX, 8(DX) - MOVQ SI, 16(DX) - MOVQ DI, 24(DX) - MOVQ R8, 32(DX) - MOVQ 40(AX), CX - MOVQ 48(AX), BX - MOVQ 56(AX), SI - MOVQ 64(AX), DI - MOVQ 72(AX), R8 - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - - // reduce element(CX,BX,SI,DI,R8) using temp registers (R14,R15,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R14,R15,R9,R10,R11) - - MOVQ CX, 40(DX) - MOVQ BX, 48(DX) - MOVQ SI, 56(DX) - MOVQ DI, 64(DX) - MOVQ R8, 72(DX) - RET - -TEXT ·subE2(SB), NOSPLIT, $0-24 - XORQ R8, R8 - MOVQ x+8(FP), DI - MOVQ 0(DI), AX - MOVQ 8(DI), DX - MOVQ 16(DI), CX - MOVQ 24(DI), BX - MOVQ 32(DI), SI - MOVQ y+16(FP), DI - SUBQ 0(DI), AX - SBBQ 8(DI), DX - SBBQ 16(DI), CX - SBBQ 24(DI), BX - SBBQ 32(DI), SI - MOVQ x+8(FP), DI - MOVQ $0x6fe802ff40300001, R9 - MOVQ $0x421ee5da52bde502, R10 - MOVQ $0xdec1d01aa27a1ae0, R11 - MOVQ $0xd3f7498be97c5eaf, R12 - MOVQ $0x04c23a02b586d650, R13 - CMOVQCC R8, R9 - CMOVQCC R8, R10 - CMOVQCC R8, R11 - CMOVQCC R8, R12 - CMOVQCC R8, R13 - ADDQ R9, AX - ADCQ R10, DX - ADCQ R11, CX - ADCQ R12, BX - ADCQ R13, SI - MOVQ res+0(FP), R14 - MOVQ AX, 0(R14) - MOVQ DX, 8(R14) - MOVQ CX, 16(R14) - MOVQ BX, 24(R14) - MOVQ SI, 32(R14) - MOVQ 40(DI), AX - MOVQ 48(DI), DX - MOVQ 56(DI), CX - MOVQ 64(DI), BX - MOVQ 72(DI), SI - MOVQ y+16(FP), DI - SUBQ 40(DI), AX - SBBQ 48(DI), DX - SBBQ 56(DI), CX - SBBQ 64(DI), BX - SBBQ 72(DI), SI - MOVQ $0x6fe802ff40300001, R15 - MOVQ $0x421ee5da52bde502, R9 - MOVQ $0xdec1d01aa27a1ae0, R10 - MOVQ $0xd3f7498be97c5eaf, R11 - MOVQ $0x04c23a02b586d650, R12 - CMOVQCC R8, R15 - CMOVQCC R8, R9 - CMOVQCC R8, R10 - CMOVQCC R8, R11 - CMOVQCC R8, R12 - ADDQ R15, AX - ADCQ R9, DX - ADCQ R10, CX - ADCQ R11, BX - ADCQ R12, SI - MOVQ res+0(FP), DI - MOVQ AX, 40(DI) - MOVQ DX, 48(DI) - MOVQ CX, 56(DI) - MOVQ BX, 64(DI) - MOVQ SI, 72(DI) - RET - -TEXT ·negE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ 32(AX), R9 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - ORQ R9, AX - TESTQ AX, AX - JNE l1 - MOVQ AX, 0(DX) - MOVQ AX, 8(DX) - MOVQ AX, 16(DX) - MOVQ AX, 24(DX) - MOVQ AX, 32(DX) - JMP l3 - -l1: - MOVQ $0x6fe802ff40300001, CX - SUBQ BX, CX - MOVQ CX, 0(DX) - MOVQ $0x421ee5da52bde502, CX - SBBQ SI, CX - MOVQ CX, 8(DX) - MOVQ $0xdec1d01aa27a1ae0, CX - SBBQ DI, CX - MOVQ CX, 16(DX) - MOVQ $0xd3f7498be97c5eaf, CX - SBBQ R8, CX - MOVQ CX, 24(DX) - MOVQ $0x04c23a02b586d650, CX - SBBQ R9, CX - MOVQ CX, 32(DX) - -l3: - MOVQ x+8(FP), AX - MOVQ 40(AX), BX - MOVQ 48(AX), SI - MOVQ 56(AX), DI - MOVQ 64(AX), R8 - MOVQ 72(AX), R9 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - ORQ R9, AX - TESTQ AX, AX - JNE l2 - MOVQ AX, 40(DX) - MOVQ AX, 48(DX) - MOVQ AX, 56(DX) - MOVQ AX, 64(DX) - MOVQ AX, 72(DX) - RET - -l2: - MOVQ $0x6fe802ff40300001, CX - SUBQ BX, CX - MOVQ CX, 40(DX) - MOVQ $0x421ee5da52bde502, CX - SBBQ SI, CX - MOVQ CX, 48(DX) - MOVQ $0xdec1d01aa27a1ae0, CX - SBBQ DI, CX - MOVQ CX, 56(DX) - MOVQ $0xd3f7498be97c5eaf, CX - SBBQ R8, CX - MOVQ CX, 64(DX) - MOVQ $0x04c23a02b586d650, CX - SBBQ R9, CX - MOVQ CX, 72(DX) - RET - -TEXT ·mulNonResE2(SB), NOSPLIT, $0-16 - MOVQ x+8(FP), AX - MOVQ 40(AX), DX - MOVQ 48(AX), CX - MOVQ 56(AX), BX - MOVQ 64(AX), SI - MOVQ 72(AX), DI - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R13,R14,R15,s0-8(SP),s1-16(SP)) - REDUCE(DX,CX,BX,SI,DI,R13,R14,R15,s0-8(SP),s1-16(SP)) - - MOVQ DX, R13 - MOVQ CX, R14 - MOVQ BX, R15 - MOVQ SI, s0-8(SP) - MOVQ DI, s1-16(SP) - ADDQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ R13, DX - ADCQ R14, CX - ADCQ R15, BX - ADCQ s0-8(SP), SI - ADCQ s1-16(SP), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - ADDQ 40(AX), DX - ADCQ 48(AX), CX - ADCQ 56(AX), BX - ADCQ 64(AX), SI - ADCQ 72(AX), DI - - // reduce element(DX,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12) - REDUCE(DX,CX,BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ res+0(FP), R13 - MOVQ 0(AX), R8 - MOVQ 8(AX), R9 - MOVQ 16(AX), R10 - MOVQ 24(AX), R11 - MOVQ 32(AX), R12 - MOVQ R8, 40(R13) - MOVQ R9, 48(R13) - MOVQ R10, 56(R13) - MOVQ R11, 64(R13) - MOVQ R12, 72(R13) - MOVQ DX, 0(R13) - MOVQ CX, 8(R13) - MOVQ BX, 16(R13) - MOVQ SI, 24(R13) - MOVQ DI, 32(R13) - RET - -TEXT ·mulAdxE2(SB), $80-24 - NO_LOCAL_POINTERS - - // var a, b, c fp.Element - // a.Add(&x.A0, &x.A1) - // b.Add(&y.A0, &y.A1) - // a.Mul(&a, &b) - // b.Mul(&x.A0, &y.A0) - // c.Mul(&x.A1, &y.A1) - // z.A1.Sub(&a, &b).Sub(&z.A1, &c) - // fp.MulBy13(&c) - // z.A0.Add(&c, &b) - - MOVQ x+8(FP), AX - MOVQ 40(AX), R14 - MOVQ 48(AX), R15 - MOVQ 56(AX), CX - MOVQ 64(AX), BX - MOVQ 72(AX), SI - - // A -> BP - // t[0] -> DI - // t[1] -> R8 - // t[2] -> R9 - // t[3] -> R10 - // t[4] -> R11 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 40(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, DI, R8 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R9 - ADOXQ AX, R8 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R10 - ADOXQ AX, R9 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R11 - ADOXQ AX, R10 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 48(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 56(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 64(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 72(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - MOVQ DI, s5-48(SP) - MOVQ R8, s6-56(SP) - MOVQ R9, s7-64(SP) - MOVQ R10, s8-72(SP) - MOVQ R11, s9-80(SP) - MOVQ x+8(FP), AX - MOVQ y+16(FP), DX - MOVQ 40(AX), R14 - MOVQ 48(AX), R15 - MOVQ 56(AX), CX - MOVQ 64(AX), BX - MOVQ 72(AX), SI - ADDQ 0(AX), R14 - ADCQ 8(AX), R15 - ADCQ 16(AX), CX - ADCQ 24(AX), BX - ADCQ 32(AX), SI - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - ADDQ 40(DX), R14 - ADCQ 48(DX), R15 - ADCQ 56(DX), CX - ADCQ 64(DX), BX - ADCQ 72(DX), SI - - // A -> BP - // t[0] -> DI - // t[1] -> R8 - // t[2] -> R9 - // t[3] -> R10 - // t[4] -> R11 - // clear the flags - XORQ AX, AX - MOVQ s0-8(SP), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, DI, R8 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R9 - ADOXQ AX, R8 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R10 - ADOXQ AX, R9 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R11 - ADOXQ AX, R10 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ DI, AX - MOVQ R13, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ s1-16(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ DI, AX - MOVQ R13, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ s2-24(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ DI, AX - MOVQ R13, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ s3-32(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ DI, AX - MOVQ R13, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ s4-40(SP), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R13 - ADCXQ DI, AX - MOVQ R13, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - MOVQ DI, s0-8(SP) - MOVQ R8, s1-16(SP) - MOVQ R9, s2-24(SP) - MOVQ R10, s3-32(SP) - MOVQ R11, s4-40(SP) - MOVQ x+8(FP), AX - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - - // A -> BP - // t[0] -> DI - // t[1] -> R8 - // t[2] -> R9 - // t[3] -> R10 - // t[4] -> R11 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 0(DX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R14, DI, R8 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R15, AX, R9 - ADOXQ AX, R8 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ CX, AX, R10 - ADOXQ AX, R9 - - // (A,t[3]) := x[3]*y[0] + A - MULXQ BX, AX, R11 - ADOXQ AX, R10 - - // (A,t[4]) := x[4]*y[0] + A - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 8(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 16(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 24(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), DX - MOVQ 32(DX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R14, AX, BP - ADOXQ AX, DI - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R8 - MULXQ R15, AX, BP - ADOXQ AX, R8 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, R9 - MULXQ CX, AX, BP - ADOXQ AX, R9 - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, R10 - MULXQ BX, AX, BP - ADOXQ AX, R10 - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, R11 - MULXQ SI, AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ DI, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ DI, AX - MOVQ R12, DI - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R8, DI - MULXQ q<>+8(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ R9, R8 - MULXQ q<>+16(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ R10, R9 - MULXQ q<>+24(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ R11, R10 - MULXQ q<>+32(SB), AX, R11 - ADOXQ AX, R10 - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - XORQ DX, DX - MOVQ s0-8(SP), R14 - MOVQ s1-16(SP), R15 - MOVQ s2-24(SP), CX - MOVQ s3-32(SP), BX - MOVQ s4-40(SP), SI - SUBQ DI, R14 - SBBQ R8, R15 - SBBQ R9, CX - SBBQ R10, BX - SBBQ R11, SI - MOVQ DI, s0-8(SP) - MOVQ R8, s1-16(SP) - MOVQ R9, s2-24(SP) - MOVQ R10, s3-32(SP) - MOVQ R11, s4-40(SP) - MOVQ $0x6fe802ff40300001, DI - MOVQ $0x421ee5da52bde502, R8 - MOVQ $0xdec1d01aa27a1ae0, R9 - MOVQ $0xd3f7498be97c5eaf, R10 - MOVQ $0x04c23a02b586d650, R11 - CMOVQCC DX, DI - CMOVQCC DX, R8 - CMOVQCC DX, R9 - CMOVQCC DX, R10 - CMOVQCC DX, R11 - ADDQ DI, R14 - ADCQ R8, R15 - ADCQ R9, CX - ADCQ R10, BX - ADCQ R11, SI - SUBQ s5-48(SP), R14 - SBBQ s6-56(SP), R15 - SBBQ s7-64(SP), CX - SBBQ s8-72(SP), BX - SBBQ s9-80(SP), SI - MOVQ $0x6fe802ff40300001, DI - MOVQ $0x421ee5da52bde502, R8 - MOVQ $0xdec1d01aa27a1ae0, R9 - MOVQ $0xd3f7498be97c5eaf, R10 - MOVQ $0x04c23a02b586d650, R11 - CMOVQCC DX, DI - CMOVQCC DX, R8 - CMOVQCC DX, R9 - CMOVQCC DX, R10 - CMOVQCC DX, R11 - ADDQ DI, R14 - ADCQ R8, R15 - ADCQ R9, CX - ADCQ R10, BX - ADCQ R11, SI - MOVQ res+0(FP), AX - MOVQ R14, 40(AX) - MOVQ R15, 48(AX) - MOVQ CX, 56(AX) - MOVQ BX, 64(AX) - MOVQ SI, 72(AX) - MOVQ s5-48(SP), DI - MOVQ s6-56(SP), R8 - MOVQ s7-64(SP), R9 - MOVQ s8-72(SP), R10 - MOVQ s9-80(SP), R11 - MOVQ s0-8(SP), R14 - MOVQ s1-16(SP), R15 - MOVQ s2-24(SP), CX - MOVQ s3-32(SP), BX - MOVQ s4-40(SP), SI - ADDQ DI, R14 - ADCQ R8, R15 - ADCQ R9, CX - ADCQ R10, BX - ADCQ R11, SI - - // reduce element(R14,R15,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ s5-48(SP), DI - MOVQ s6-56(SP), R8 - MOVQ s7-64(SP), R9 - MOVQ s8-72(SP), R10 - MOVQ s9-80(SP), R11 - MOVQ R14, s5-48(SP) - MOVQ R15, s6-56(SP) - MOVQ CX, s7-64(SP) - MOVQ BX, s8-72(SP) - MOVQ SI, s9-80(SP) - ADDQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - ADDQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - MOVQ DI, s0-8(SP) - MOVQ R8, s1-16(SP) - MOVQ R9, s2-24(SP) - MOVQ R10, s3-32(SP) - MOVQ R11, s4-40(SP) - ADDQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - ADCQ R10, R10 - ADCQ R11, R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - ADDQ s0-8(SP), DI - ADCQ s1-16(SP), R8 - ADCQ s2-24(SP), R9 - ADCQ s3-32(SP), R10 - ADCQ s4-40(SP), R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - ADDQ s5-48(SP), DI - ADCQ s6-56(SP), R8 - ADCQ s7-64(SP), R9 - ADCQ s8-72(SP), R10 - ADCQ s9-80(SP), R11 - - // reduce element(DI,R8,R9,R10,R11) using temp registers (R14,R15,CX,BX,SI) - REDUCE(DI,R8,R9,R10,R11,R14,R15,CX,BX,SI) - - MOVQ DI, 0(AX) - MOVQ R8, 8(AX) - MOVQ R9, 16(AX) - MOVQ R10, 24(AX) - MOVQ R11, 32(AX) - RET diff --git a/ecc/bls24-315/internal/fptower/e2_amd64.s b/ecc/bls24-315/internal/fptower/e2_amd64.s index 4b5acc809..6f9615547 100644 --- a/ecc/bls24-315/internal/fptower/e2_amd64.s +++ b/ecc/bls24-315/internal/fptower/e2_amd64.s @@ -1,5 +1,3 @@ -// +build !amd64_adx - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls24-315/internal/fptower/e4.go b/ecc/bls24-315/internal/fptower/e4.go index 06b537415..a77125da6 100644 --- a/ecc/bls24-315/internal/fptower/e4.go +++ b/ecc/bls24-315/internal/fptower/e4.go @@ -32,10 +32,9 @@ func (z *E4) Equal(x *E4) bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E4) Cmp(x *E4) int { if a1 := z.B1.Cmp(&x.B1); a1 != 0 { return a1 @@ -86,20 +85,6 @@ func (z *E4) SetOne() *E4 { return z } -// ToMont converts to Mont form -func (z *E4) ToMont() *E4 { - z.B0.ToMont() - z.B1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E4) FromMont() *E4 { - z.B0.FromMont() - z.B1.FromMont() - return z -} - // MulByElement multiplies an element in E4 by an element in fp func (z *E4) MulByElement(x *E4, y *fp.Element) *E4 { var yCopy fp.Element @@ -153,6 +138,10 @@ func (z *E4) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() } +func (z *E4) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() +} + // MulByNonResidue mul x by (0,1) func (z *E4) MulByNonResidue(x *E4) *E4 { z.B1, z.B0 = x.B0, x.B1 diff --git a/ecc/bls24-315/marshal.go b/ecc/bls24-315/marshal.go index aa0e92d75..944af6694 100644 --- a/ecc/bls24-315/marshal.go +++ b/ecc/bls24-315/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,13 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -679,27 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -750,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -771,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -845,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -854,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -863,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -906,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -918,37 +911,10 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - tmp = p.X.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) - - tmp = p.X.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - - tmp = p.X.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[112:120], tmp[0]) - binary.BigEndian.PutUint64(res[104:112], tmp[1]) - binary.BigEndian.PutUint64(res[96:104], tmp[2]) - binary.BigEndian.PutUint64(res[88:96], tmp[3]) - binary.BigEndian.PutUint64(res[80:88], tmp[4]) - - tmp = p.X.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.X.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.X.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[120:120+fp.Bytes]), p.X.B0.A0) res[0] |= msbMask @@ -967,77 +933,20 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.B1.A1 | p.Y.B1.A0 | p.Y.B0.A1 | p.Y.B0.A0 - tmp = p.Y.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[192:200], tmp[0]) - binary.BigEndian.PutUint64(res[184:192], tmp[1]) - binary.BigEndian.PutUint64(res[176:184], tmp[2]) - binary.BigEndian.PutUint64(res[168:176], tmp[3]) - binary.BigEndian.PutUint64(res[160:168], tmp[4]) - - tmp = p.Y.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[232:240], tmp[0]) - binary.BigEndian.PutUint64(res[224:232], tmp[1]) - binary.BigEndian.PutUint64(res[216:224], tmp[2]) - binary.BigEndian.PutUint64(res[208:216], tmp[3]) - binary.BigEndian.PutUint64(res[200:208], tmp[4]) - - tmp = p.Y.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[272:280], tmp[0]) - binary.BigEndian.PutUint64(res[264:272], tmp[1]) - binary.BigEndian.PutUint64(res[256:264], tmp[2]) - binary.BigEndian.PutUint64(res[248:256], tmp[3]) - binary.BigEndian.PutUint64(res[240:248], tmp[4]) - - tmp = p.Y.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[312:320], tmp[0]) - binary.BigEndian.PutUint64(res[304:312], tmp[1]) - binary.BigEndian.PutUint64(res[296:304], tmp[2]) - binary.BigEndian.PutUint64(res[288:296], tmp[3]) - binary.BigEndian.PutUint64(res[280:288], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[160:160+fp.Bytes]), p.Y.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[200:200+fp.Bytes]), p.Y.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[240:240+fp.Bytes]), p.Y.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[280:280+fp.Bytes]), p.Y.B0.A0) // we store X and mask the most significant word with our metadata mask // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - tmp = p.X.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) - - tmp = p.X.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - - tmp = p.X.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[112:120], tmp[0]) - binary.BigEndian.PutUint64(res[104:112], tmp[1]) - binary.BigEndian.PutUint64(res[96:104], tmp[2]) - binary.BigEndian.PutUint64(res[88:96], tmp[3]) - binary.BigEndian.PutUint64(res[80:88], tmp[4]) - - tmp = p.X.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.X.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.X.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[120:120+fp.Bytes]), p.X.B0.A0) res[0] |= mUncompressed @@ -1089,15 +998,31 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(buf[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(buf[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // p.Y.B1.A1 | p.Y.B1.A0 | p.Y.B0.A1 | p.Y.B0.A0 - p.Y.B1.A1.SetBytes(buf[fp.Bytes*4 : fp.Bytes*5]) - p.Y.B1.A0.SetBytes(buf[fp.Bytes*5 : fp.Bytes*6]) - p.Y.B0.A1.SetBytes(buf[fp.Bytes*6 : fp.Bytes*7]) - p.Y.B0.A0.SetBytes(buf[fp.Bytes*7 : fp.Bytes*8]) + if err := p.Y.B1.A1.SetBytesCanonical(buf[fp.Bytes*4 : fp.Bytes*5]); err != nil { + return 0, err + } + if err := p.Y.B1.A0.SetBytesCanonical(buf[fp.Bytes*5 : fp.Bytes*6]); err != nil { + return 0, err + } + if err := p.Y.B0.A1.SetBytesCanonical(buf[fp.Bytes*6 : fp.Bytes*7]); err != nil { + return 0, err + } + if err := p.Y.B0.A0.SetBytesCanonical(buf[fp.Bytes*7 : fp.Bytes*8]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1118,10 +1043,18 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } var YSquared, Y fptower.E4 @@ -1197,7 +1130,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1206,7 +1139,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1216,14 +1149,22 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return false, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return false, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return false, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return false, err + } // store mData in p.Y.B0.A0[0] p.Y.B0.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bls24-315/multiexp.go b/ecc/bls24-315/multiexp.go index 0c3d0039b..52abe34e0 100644 --- a/ecc/bls24-315/multiexp.go +++ b/ecc/bls24-315/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,177 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +282,445 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 2: + return processChunkG2Jacobian[bucketg2JacExtendedC2] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bls24-315/multiexp_affine.go b/ecc/bls24-315/multiexp_affine.go new file mode 100644 index 000000000..40a45408f --- /dev/null +++ b/ecc/bls24-315/multiexp_affine.go @@ -0,0 +1,686 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls24315 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-315/fp" + "github.com/consensys/gnark-crypto/ecc/bls24-315/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E4 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E4 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E4 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E4 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E4 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E4 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E4 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E4 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC2 [2]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bls24-315/multiexp_jacobian.go b/ecc/bls24-315/multiexp_jacobian.go new file mode 100644 index 000000000..be0bb121b --- /dev/null +++ b/ecc/bls24-315/multiexp_jacobian.go @@ -0,0 +1,171 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls24315 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC2 [2]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC2 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bls24-315/multiexp_test.go b/ecc/bls24-315/multiexp_test.go index 713bedc00..945da773a 100644 --- a/ecc/bls24-315/multiexp_test.go +++ b/ecc/bls24-315/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bls24-315/twistededwards/eddsa/doc.go b/ecc/bls24-315/twistededwards/eddsa/doc.go index b48e0daeb..1c5a95464 100644 --- a/ecc/bls24-315/twistededwards/eddsa/doc.go +++ b/ecc/bls24-315/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls24-315's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go b/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go index 334b1cc0e..669bce5e6 100644 --- a/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls24-315/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls24-315/twistededwards/eddsa/marshal.go b/ecc/bls24-315/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls24-315/twistededwards/eddsa/marshal.go +++ b/ecc/bls24-315/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls24-315/twistededwards/point.go b/ecc/bls24-315/twistededwards/point.go index f717b9016..a460738f3 100644 --- a/ecc/bls24-315/twistededwards/point.go +++ b/ecc/bls24-315/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bls24-317/bls24-317.go b/ecc/bls24-317/bls24-317.go index ee1cb5325..9d2eb9c53 100644 --- a/ecc/bls24-317/bls24-317.go +++ b/ecc/bls24-317/bls24-317.go @@ -1,24 +1,30 @@ // Package bls24317 efficient elliptic curve, pairing and hash to curve implementation for bls24-317. // // bls24-317: A Barreto--Lynn--Scott curve -// embedding degree k=24 -// seed x₀=3640754176 -// 𝔽r: r=30869589236456844204538189757527902584594726589286811523515204428962673459201 (x₀^8-x₀^4+2) -// 𝔽p: p=136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 ((x₀-1)² ⋅ r(x₀)/3+x₀) -// (E/𝔽p): Y²=X³+4 -// (Eₜ/𝔽p⁴): Y² = X³+4v (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p⁴) +// +// embedding degree k=24 +// seed x₀=3640754176 +// 𝔽r: r=30869589236456844204538189757527902584594726589286811523515204428962673459201 (x₀^8-x₀^4+2) +// 𝔽p: p=136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 ((x₀-1)² ⋅ r(x₀)/3+x₀) +// (E/𝔽p): Y²=X³+4 +// (Eₜ/𝔽p⁴): Y² = X³+4v (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p⁴) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²+1 -// 𝔽p⁴[v] = 𝔽p²/v²-u-1 -// 𝔽p¹²[w] = 𝔽p⁴/w³-v -// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁴[v] = 𝔽p²/v²-u-1 +// 𝔽p¹²[w] = 𝔽p⁴/w³-v +// 𝔽p²⁴[i] = 𝔽p¹²/i²-w +// // optimal Ate loop size: -// x₀ +// +// x₀ +// // Security: estimated 160-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 255 bits and p²⁴ is 7599 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bls24317 diff --git a/ecc/bls24-317/fp/doc.go b/ecc/bls24-317/fp/doc.go index 7940cd814..59f6de342 100644 --- a/ecc/bls24-317/fp/doc.go +++ b/ecc/bls24-317/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [5]uint64 // -// Usage +// type Element [5]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 -// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab +// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 +// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bls24-317/fp/element.go b/ecc/bls24-317/fp/element.go index a71f578b9..8770432dd 100644 --- a/ecc/bls24-317/fp/element.go +++ b/ecc/bls24-317/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 5 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 -// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab +// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 +// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [5]uint64 const ( - Limbs = 5 // number of 64 bits words needed to represent a Element - Bits = 317 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 5 // number of 64 bits words needed to represent a Element + Bits = 317 // number of bits needed to represent a Element + Bytes = 40 // number of bytes needed to represent a Element ) // Field modulus q @@ -70,8 +70,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 -// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab +// q[base10] = 136393071104295911515099765908274057061945112121419593977210139303905973197232025618026156731051 +// q[base16] = 0x1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -80,12 +80,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 6176088765535387645 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("1058ca226f60892cf28fc5a0b7f9d039169a61e684c73446d6f339e43424bf7e8d512e565dab2aab", 16) } @@ -93,8 +87,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -105,7 +100,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -136,14 +131,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -255,15 +251,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -275,15 +269,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[4] > _x[4] { return 1 } else if _z[4] < _x[4] { @@ -319,8 +310,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 5091485590467482966, 0) @@ -415,67 +405,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -490,7 +422,7 @@ func (z *Element) Add(x, y *Element) *Element { z[3], carry = bits.Add64(x[3], y[3], carry) z[4], _ = bits.Add64(x[4], y[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -512,7 +444,7 @@ func (z *Element) Double(x *Element) *Element { z[3], carry = bits.Add64(x[3], x[3], carry) z[4], _ = bits.Add64(x[4], x[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -570,88 +502,181 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [5]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - z[4], z[3] = madd3(m, q4, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [6]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + + if t[5] != 0 { + // we need to reduce, we have a result on 6 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], _ = bits.Sub64(t[4], q4, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -660,7 +685,6 @@ func _mulGeneric(z, x, y *Element) { z[3], b = bits.Sub64(z[3], q3, b) z[4], _ = bits.Sub64(z[4], q4, b) } - } func _fromMontGeneric(z *Element) { @@ -718,7 +742,7 @@ func _fromMontGeneric(z *Element) { z[4] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -731,7 +755,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -799,6 +823,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -813,8 +866,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -841,23 +894,30 @@ var rSquare = Element{ 1146553493836047074, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[32:40], z[0]) + binary.BigEndian.PutUint64(b[24:32], z[1]) + binary.BigEndian.PutUint64(b[16:24], z[2]) + binary.BigEndian.PutUint64(b[8:16], z[3]) + binary.BigEndian.PutUint64(b[0:8], z[4]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -876,49 +936,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[32:40], z[0]) - binary.BigEndian.PutUint64(b[24:32], z[1]) - binary.BigEndian.PutUint64(b[16:24], z[2]) - binary.BigEndian.PutUint64(b[8:16], z[3]) - binary.BigEndian.PutUint64(b[0:8], z[4]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[32:40], _z[0]) - binary.BigEndian.PutUint64(res[24:32], _z[1]) - binary.BigEndian.PutUint64(res[16:24], _z[2]) - binary.BigEndian.PutUint64(res[8:16], _z[3]) - binary.BigEndian.PutUint64(res[0:8], _z[4]) +// Bits provides access to z by returning its value as a little-endian [5]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [5]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -931,19 +991,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 40-byte integer. +// If e is not a 40-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -961,17 +1046,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -993,20 +1077,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1015,7 +1099,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1024,7 +1108,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1064,7 +1148,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1073,10 +1157,83 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 40-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[32:40]) + z[1] = binary.BigEndian.Uint64((*b)[24:32]) + z[2] = binary.BigEndian.Uint64((*b)[16:24]) + z[3] = binary.BigEndian.Uint64((*b)[8:16]) + z[4] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[32:40], e[0]) + binary.BigEndian.PutUint64((*b)[24:32], e[1]) + binary.BigEndian.PutUint64((*b)[16:24], e[2]) + binary.BigEndian.PutUint64((*b)[8:16], e[3]) + binary.BigEndian.PutUint64((*b)[0:8], e[4]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1264,7 +1421,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1280,17 +1437,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1405,7 +1573,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[4], z[3] = madd2(m, q4, t[i+4], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls24-317/fp/element_mul_adx_amd64.s b/ecc/bls24-317/fp/element_mul_adx_amd64.s deleted file mode 100644 index 6c2bb89b4..000000000 --- a/ecc/bls24-317/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,634 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8d512e565dab2aab -DATA q<>+8(SB)/8, $0xd6f339e43424bf7e -DATA q<>+16(SB)/8, $0x169a61e684c73446 -DATA q<>+24(SB)/8, $0xf28fc5a0b7f9d039 -DATA q<>+32(SB)/8, $0x1058ca226f60892c -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x55b5e0028b047ffd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET diff --git a/ecc/bls24-317/fp/element_mul_amd64.s b/ecc/bls24-317/fp/element_mul_amd64.s index 88c1b2e2d..56bfe818a 100644 --- a/ecc/bls24-317/fp/element_mul_amd64.s +++ b/ecc/bls24-317/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fp/element_ops_amd64.go b/ecc/bls24-317/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.go +++ b/ecc/bls24-317/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls24-317/fp/element_ops_amd64.s b/ecc/bls24-317/fp/element_ops_amd64.s index 328d16f8b..cb68645b3 100644 --- a/ecc/bls24-317/fp/element_ops_amd64.s +++ b/ecc/bls24-317/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls24-317/fp/element_ops_noasm.go b/ecc/bls24-317/fp/element_ops_noasm.go deleted file mode 100644 index 760b7aaac..000000000 --- a/ecc/bls24-317/fp/element_ops_noasm.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 17338930599381248615, - 10169435867607475877, - 1410856163759197139, - 12105193723137614523, - 691221942076914011, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls24-317/fp/element_ops_purego.go b/ecc/bls24-317/fp/element_ops_purego.go new file mode 100644 index 000000000..aed04e01f --- /dev/null +++ b/ecc/bls24-317/fp/element_ops_purego.go @@ -0,0 +1,582 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17338930599381248615, + 10169435867607475877, + 1410856163759197139, + 12105193723137614523, + 691221942076914011, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} diff --git a/ecc/bls24-317/fp/element_test.go b/ecc/bls24-317/fp/element_test.go index 79452bcd8..e71554c17 100644 --- a/ecc/bls24-317/fp/element_test.go +++ b/ecc/bls24-317/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -640,7 +633,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -755,7 +748,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -768,13 +761,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -803,17 +796,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -864,7 +857,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -877,13 +870,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -912,17 +905,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -973,7 +966,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -986,7 +979,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1000,7 +993,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1040,11 +1033,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1057,7 +1050,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1109,7 +1102,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1122,14 +1115,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1158,18 +1151,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1220,7 +1213,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1233,13 +1226,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1268,17 +1261,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1323,7 +1316,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1344,14 +1337,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1395,7 +1388,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1416,14 +1409,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1467,7 +1460,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1488,14 +1481,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1539,7 +1532,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1560,14 +1553,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1611,7 +1604,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1632,14 +1625,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2027,7 +2020,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2160,17 +2153,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2251,7 +2244,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2316,7 +2309,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2359,7 +2352,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord3, inversionCorrectionFactorWord4, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2401,7 +2394,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2548,7 +2541,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2761,11 +2754,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2802,7 +2795,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2820,7 +2813,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2831,7 +2824,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls24-317/fr/doc.go b/ecc/bls24-317/fr/doc.go index a6431d9f1..860ebf253 100644 --- a/ecc/bls24-317/fr/doc.go +++ b/ecc/bls24-317/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 -// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 +// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 +// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bls24-317/fr/element.go b/ecc/bls24-317/fr/element.go index 3200e0e54..0fe1f2287 100644 --- a/ecc/bls24-317/fr/element.go +++ b/ecc/bls24-317/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 -// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 +// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 +// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 255 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 255 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 -// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 +// q[base10] = 30869589236456844204538189757527902584594726589286811523515204428962673459201 +// q[base16] = 0x443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 17293822569102704639 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("443f917ea68dafc2d0b097f28d83cd491cd1e79196bf0e7af000000000000001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 8646911284551352321, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 4216292045776253362, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(60) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bls24-317/fr/element_mul_adx_amd64.s b/ecc/bls24-317/fr/element_mul_adx_amd64.s deleted file mode 100644 index 0a9fb44d0..000000000 --- a/ecc/bls24-317/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf000000000000001 -DATA q<>+8(SB)/8, $0x1cd1e79196bf0e7a -DATA q<>+16(SB)/8, $0xd0b097f28d83cd49 -DATA q<>+24(SB)/8, $0x443f917ea68dafc2 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xefffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bls24-317/fr/element_mul_amd64.s b/ecc/bls24-317/fr/element_mul_amd64.s index 9cc3e52ca..150801a1a 100644 --- a/ecc/bls24-317/fr/element_mul_amd64.s +++ b/ecc/bls24-317/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bls24-317/fr/element_ops_amd64.go b/ecc/bls24-317/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.go +++ b/ecc/bls24-317/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bls24-317/fr/element_ops_amd64.s b/ecc/bls24-317/fr/element_ops_amd64.s index 21463ba99..a62d5598f 100644 --- a/ecc/bls24-317/fr/element_ops_amd64.s +++ b/ecc/bls24-317/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bls24-317/fr/element_ops_noasm.go b/ecc/bls24-317/fr/element_ops_noasm.go deleted file mode 100644 index 1b9b96d8c..000000000 --- a/ecc/bls24-317/fr/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 18446744073709551568, - 10999079689622735090, - 16060824205876888138, - 3752826977836272504, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bls24-317/fr/element_ops_purego.go b/ecc/bls24-317/fr/element_ops_purego.go new file mode 100644 index 000000000..7b7f9352d --- /dev/null +++ b/ecc/bls24-317/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18446744073709551568, + 10999079689622735090, + 16060824205876888138, + 3752826977836272504, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bls24-317/fr/element_test.go b/ecc/bls24-317/fr/element_test.go index 08d1a9278..df76095cc 100644 --- a/ecc/bls24-317/fr/element_test.go +++ b/ecc/bls24-317/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bls24-317/fr/fri/fri.go b/ecc/bls24-317/fr/fri/fri.go index 66c68ea06..0ea7a0d9d 100644 --- a/ecc/bls24-317/fr/fri/fri.go +++ b/ecc/bls24-317/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bls24-317/fr/gkr/gkr.go b/ecc/bls24-317/fr/gkr/gkr.go new file mode 100644 index 000000000..2de626540 --- /dev/null +++ b/ecc/bls24-317/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bls24-317/fr/gkr/gkr_test.go b/ecc/bls24-317/fr/gkr/gkr_test.go new file mode 100644 index 000000000..4b9aec71e --- /dev/null +++ b/ecc/bls24-317/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bls24-317/fr/kzg/kzg.go b/ecc/bls24-317/fr/kzg/kzg.go index 38e7015e8..5afa077e4 100644 --- a/ecc/bls24-317/fr/kzg/kzg.go +++ b/ecc/bls24-317/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bls24317.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bls24317.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bls24-317/fr/mimc/decompose.go b/ecc/bls24-317/fr/mimc/decompose.go new file mode 100644 index 000000000..b027d8bbc --- /dev/null +++ b/ecc/bls24-317/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bls24-317/fr/mimc/decompose_test.go b/ecc/bls24-317/fr/mimc/decompose_test.go new file mode 100644 index 000000000..26a476920 --- /dev/null +++ b/ecc/bls24-317/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bls24-317/fr/mimc/mimc.go b/ecc/bls24-317/fr/mimc/mimc.go index 45a926276..c950b0575 100644 --- a/ecc/bls24-317/fr/mimc/mimc.go +++ b/ecc/bls24-317/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bls24-317/fr/pedersen/pedersen.go b/ecc/bls24-317/fr/pedersen/pedersen.go new file mode 100644 index 000000000..9b6777fa8 --- /dev/null +++ b/ecc/bls24-317/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bls24317.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bls24317.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bls24317.G1Affine + basisExpSigma []bls24317.G1Affine +} + +func randomOnG2() (bls24317.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls24317.G2Affine{}, err + } + return bls24317.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bls24317.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bls24317.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bls24317.G1Affine, knowledgeProof bls24317.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bls24317.G1Affine, knowledgeProof bls24317.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bls24317.Pair([]bls24317.G1Affine{commitment, knowledgeProof}, []bls24317.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bls24-317/fr/pedersen/pedersen_test.go b/ecc/bls24-317/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..fe0f0b48f --- /dev/null +++ b/ecc/bls24-317/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-317" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bls24317.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bls24317.G1Affine{}, err + } + return bls24317.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bls24317.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bls24317.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bls24-317/fr/plookup/vector.go b/ecc/bls24-317/fr/plookup/vector.go index b8f112bdd..c698941b2 100644 --- a/ecc/bls24-317/fr/plookup/vector.go +++ b/ecc/bls24-317/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bls24-317/fr/polynomial/multilin.go b/ecc/bls24-317/fr/polynomial/multilin.go index 891137fea..f116a7563 100644 --- a/ecc/bls24-317/fr/polynomial/multilin.go +++ b/ecc/bls24-317/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bls24-317/fr/polynomial/polynomial_test.go b/ecc/bls24-317/fr/polynomial/polynomial_test.go index 8998b04aa..25d87b841 100644 --- a/ecc/bls24-317/fr/polynomial/polynomial_test.go +++ b/ecc/bls24-317/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bls24-317/fr/polynomial/pool.go b/ecc/bls24-317/fr/polynomial/pool.go index 8e260ebad..409fbee82 100644 --- a/ecc/bls24-317/fr/polynomial/pool.go +++ b/ecc/bls24-317/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bls24-317/fr/sumcheck/sumcheck.go b/ecc/bls24-317/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..81ff55743 --- /dev/null +++ b/ecc/bls24-317/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bls24-317/fr/sumcheck/sumcheck_test.go b/ecc/bls24-317/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..f1f458056 --- /dev/null +++ b/ecc/bls24-317/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go b/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..a488ebc5a --- /dev/null +++ b/ecc/bls24-317/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" + "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bls24-317/g1.go b/ecc/bls24-317/g1.go index 24670352e..9a4d35cf7 100644 --- a/ecc/bls24-317/g1.go +++ b/ecc/bls24-317/g1.go @@ -17,13 +17,12 @@ package bls24317 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -373,6 +379,7 @@ func (p *G1Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG1Affine, 1] is the kernel // of (u,v)->u+lambdaG1Affinev mod r. Expressing r, lambdaG1Affine as // polynomials in x, a short vector of this Zmodule is @@ -474,8 +481,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -585,15 +592,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -605,11 +612,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -634,6 +637,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -876,95 +881,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -978,3 +960,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls24-317/g1_test.go b/ecc/bls24-317/g1_test.go index 2c08510b1..dc50ffb49 100644 --- a/ecc/bls24-317/g1_test.go +++ b/ecc/bls24-317/g1_test.go @@ -19,6 +19,7 @@ package bls24317 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls24-317/g2.go b/ecc/bls24-317/g2.go index bf9e97885..8a0b6c0b6 100644 --- a/ecc/bls24-317/g2.go +++ b/ecc/bls24-317/g2.go @@ -17,13 +17,12 @@ package bls24317 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" "github.com/consensys/gnark-crypto/ecc/bls24-317/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E4 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E4 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -371,7 +377,8 @@ func (p *G2Jac) IsOnCurve() bool { // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // https://eprint.iacr.org/2021/1130.pdf, sec.4 -// ψ(p) = x₀ P +// and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 +// ψ(p) = [x₀]P func (p *G2Jac) IsInSubGroup() bool { var res, tmp G2Jac tmp.psi(p) @@ -473,8 +480,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -616,15 +623,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E4 + var A, B, U1, U2, S1, S2 fptower.E4 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -636,11 +643,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E4 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E4 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -665,6 +668,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E4 @@ -888,93 +893,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -987,3 +969,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E4 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E4 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bls24-317/g2_test.go b/ecc/bls24-317/g2_test.go index 376b46934..89046a867 100644 --- a/ecc/bls24-317/g2_test.go +++ b/ecc/bls24-317/g2_test.go @@ -19,6 +19,7 @@ package bls24317 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bls24-317/internal/fptower" @@ -339,7 +340,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -464,8 +465,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -478,7 +478,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -505,6 +505,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -517,8 +544,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bls24-317/hash_to_g1.go b/ecc/bls24-317/hash_to_g1.go index ab839928e..5e5055043 100644 --- a/ecc/bls24-317/hash_to_g1.go +++ b/ecc/bls24-317/hash_to_g1.go @@ -17,7 +17,6 @@ package bls24317 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" "math/big" @@ -233,35 +232,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -279,11 +257,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -299,9 +277,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bls24-317/hash_to_g1_test.go b/ecc/bls24-317/hash_to_g1_test.go index 177c64c72..4abccacc1 100644 --- a/ecc/bls24-317/hash_to_g1_test.go +++ b/ecc/bls24-317/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bls24-317/hash_to_g2.go b/ecc/bls24-317/hash_to_g2.go index ea3fab244..6e2411b78 100644 --- a/ecc/bls24-317/hash_to_g2.go +++ b/ecc/bls24-317/hash_to_g2.go @@ -106,7 +106,7 @@ func MapToG2(t fptower.E4) G2Affine { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-2.2.2 func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - _t, err := hashToFp(msg, dst, 2) + _t, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -121,7 +121,7 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-3 func HashToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 4) + u, err := fp.Hash(msg, dst, 4) if err != nil { return res, err } diff --git a/ecc/bls24-317/internal/fptower/e12.go b/ecc/bls24-317/internal/fptower/e12.go index fc4785764..f1f296d74 100644 --- a/ecc/bls24-317/internal/fptower/e12.go +++ b/ecc/bls24-317/internal/fptower/e12.go @@ -75,20 +75,8 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() && z.C2.IsZero() } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - z.C2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - z.C2.FromMont() - return z +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() && z.C2.IsZero() } // Add adds two elements of E12 diff --git a/ecc/bls24-317/internal/fptower/e2.go b/ecc/bls24-317/internal/fptower/e2.go index 688d71776..c8091c6f9 100644 --- a/ecc/bls24-317/internal/fptower/e2.go +++ b/ecc/bls24-317/internal/fptower/e2.go @@ -15,8 +15,9 @@ package fptower import ( - "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" "math/big" + + "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" ) // E2 is a degree two finite field extension of fp.Element @@ -31,10 +32,9 @@ func (z *E2) Equal(x *E2) bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -96,6 +96,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -125,20 +129,6 @@ func (z *E2) String() string { return (z.A0.String() + "+" + z.A1.String() + "*u") } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bls24-317/internal/fptower/e24.go b/ecc/bls24-317/internal/fptower/e24.go index 1c8432093..a8d403b6b 100644 --- a/ecc/bls24-317/internal/fptower/e24.go +++ b/ecc/bls24-317/internal/fptower/e24.go @@ -66,20 +66,6 @@ func (z *E24) SetOne() *E24 { return z } -// ToMont converts to Mont form -func (z *E24) ToMont() *E24 { - z.D0.ToMont() - z.D1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E24) FromMont() *E24 { - z.D0.FromMont() - z.D1.FromMont() - return z -} - // Add set z=x+y in E24 and return z func (z *E24) Add(x, y *E24) *E24 { z.D0.Add(&x.D0, &y.D0) @@ -117,6 +103,10 @@ func (z *E24) IsZero() bool { return z.D0.IsZero() && z.D1.IsZero() } +func (z *E24) IsOne() bool { + return z.D0.IsOne() && z.D1.IsZero() +} + // Mul set z=x*y in E24 and return z func (z *E24) Mul(x, y *E24) *E24 { var a, b, c E12 @@ -224,9 +214,12 @@ func (z *E24) CyclotomicSquareCompressed(x *E24) *E24 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -251,7 +244,7 @@ func (z *E24) DecompressKarabina(x *E24) *E24 { t[1].Sub(&t[0], &x.D0.C2). Double(&t[1]). Add(&t[1], &t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t[2].Square(&x.D1.C2) t[0].MulByNonResidue(&t[2]). Add(&t[0], &t[1]) @@ -287,9 +280,12 @@ func (z *E24) DecompressKarabina(x *E24) *E24 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -325,7 +321,7 @@ func BatchDecompressKarabina(x []E24) []E24 { t1[i].Sub(&t0[i], &x[i].D0.C2). Double(&t1[i]). Add(&t1[i], &t0[i]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t2[i].Square(&x[i].D1.C2) t0[i].MulByNonResidue(&t2[i]). Add(&t0[i], &t1[i]) @@ -600,8 +596,8 @@ func (z *E24) ExpGLV(x E24, k *big.Int) *E24 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1)/2 + 1; i >= 0; i-- { diff --git a/ecc/bls24-317/internal/fptower/e4.go b/ecc/bls24-317/internal/fptower/e4.go index c6319aa9b..71e514882 100644 --- a/ecc/bls24-317/internal/fptower/e4.go +++ b/ecc/bls24-317/internal/fptower/e4.go @@ -32,10 +32,9 @@ func (z *E4) Equal(x *E4) bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E4) Cmp(x *E4) int { if a1 := z.B1.Cmp(&x.B1); a1 != 0 { return a1 @@ -86,20 +85,6 @@ func (z *E4) SetOne() *E4 { return z } -// ToMont converts to Mont form -func (z *E4) ToMont() *E4 { - z.B0.ToMont() - z.B1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E4) FromMont() *E4 { - z.B0.FromMont() - z.B1.FromMont() - return z -} - // MulByElement multiplies an element in E4 by an element in fp func (z *E4) MulByElement(x *E4, y *fp.Element) *E4 { var yCopy fp.Element @@ -153,6 +138,10 @@ func (z *E4) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() } +func (z *E4) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() +} + // MulByNonResidue mul x by (0,1) func (z *E4) MulByNonResidue(x *E4) *E4 { z.B1, z.B0 = x.B0, x.B1 diff --git a/ecc/bls24-317/marshal.go b/ecc/bls24-317/marshal.go index 23e749ab8..ddd960555 100644 --- a/ecc/bls24-317/marshal.go +++ b/ecc/bls24-317/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,13 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -679,27 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -750,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -771,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -845,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -854,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -863,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -906,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -918,37 +911,10 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - tmp = p.X.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) - - tmp = p.X.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - - tmp = p.X.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[112:120], tmp[0]) - binary.BigEndian.PutUint64(res[104:112], tmp[1]) - binary.BigEndian.PutUint64(res[96:104], tmp[2]) - binary.BigEndian.PutUint64(res[88:96], tmp[3]) - binary.BigEndian.PutUint64(res[80:88], tmp[4]) - - tmp = p.X.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.X.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.X.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[120:120+fp.Bytes]), p.X.B0.A0) res[0] |= msbMask @@ -967,77 +933,20 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.B1.A1 | p.Y.B1.A0 | p.Y.B0.A1 | p.Y.B0.A0 - tmp = p.Y.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[192:200], tmp[0]) - binary.BigEndian.PutUint64(res[184:192], tmp[1]) - binary.BigEndian.PutUint64(res[176:184], tmp[2]) - binary.BigEndian.PutUint64(res[168:176], tmp[3]) - binary.BigEndian.PutUint64(res[160:168], tmp[4]) - - tmp = p.Y.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[232:240], tmp[0]) - binary.BigEndian.PutUint64(res[224:232], tmp[1]) - binary.BigEndian.PutUint64(res[216:224], tmp[2]) - binary.BigEndian.PutUint64(res[208:216], tmp[3]) - binary.BigEndian.PutUint64(res[200:208], tmp[4]) - - tmp = p.Y.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[272:280], tmp[0]) - binary.BigEndian.PutUint64(res[264:272], tmp[1]) - binary.BigEndian.PutUint64(res[256:264], tmp[2]) - binary.BigEndian.PutUint64(res[248:256], tmp[3]) - binary.BigEndian.PutUint64(res[240:248], tmp[4]) - - tmp = p.Y.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[312:320], tmp[0]) - binary.BigEndian.PutUint64(res[304:312], tmp[1]) - binary.BigEndian.PutUint64(res[296:304], tmp[2]) - binary.BigEndian.PutUint64(res[288:296], tmp[3]) - binary.BigEndian.PutUint64(res[280:288], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[160:160+fp.Bytes]), p.Y.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[200:200+fp.Bytes]), p.Y.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[240:240+fp.Bytes]), p.Y.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[280:280+fp.Bytes]), p.Y.B0.A0) // we store X and mask the most significant word with our metadata mask // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - tmp = p.X.B1.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[32:40], tmp[0]) - binary.BigEndian.PutUint64(res[24:32], tmp[1]) - binary.BigEndian.PutUint64(res[16:24], tmp[2]) - binary.BigEndian.PutUint64(res[8:16], tmp[3]) - binary.BigEndian.PutUint64(res[0:8], tmp[4]) - - tmp = p.X.B1.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - - tmp = p.X.B0.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[112:120], tmp[0]) - binary.BigEndian.PutUint64(res[104:112], tmp[1]) - binary.BigEndian.PutUint64(res[96:104], tmp[2]) - binary.BigEndian.PutUint64(res[88:96], tmp[3]) - binary.BigEndian.PutUint64(res[80:88], tmp[4]) - - tmp = p.X.B0.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[40:40+fp.Bytes]), p.X.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.X.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[120:120+fp.Bytes]), p.X.B0.A0) res[0] |= mUncompressed @@ -1089,15 +998,31 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(buf[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(buf[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // p.Y.B1.A1 | p.Y.B1.A0 | p.Y.B0.A1 | p.Y.B0.A0 - p.Y.B1.A1.SetBytes(buf[fp.Bytes*4 : fp.Bytes*5]) - p.Y.B1.A0.SetBytes(buf[fp.Bytes*5 : fp.Bytes*6]) - p.Y.B0.A1.SetBytes(buf[fp.Bytes*6 : fp.Bytes*7]) - p.Y.B0.A0.SetBytes(buf[fp.Bytes*7 : fp.Bytes*8]) + if err := p.Y.B1.A1.SetBytesCanonical(buf[fp.Bytes*4 : fp.Bytes*5]); err != nil { + return 0, err + } + if err := p.Y.B1.A0.SetBytesCanonical(buf[fp.Bytes*5 : fp.Bytes*6]); err != nil { + return 0, err + } + if err := p.Y.B0.A1.SetBytesCanonical(buf[fp.Bytes*6 : fp.Bytes*7]); err != nil { + return 0, err + } + if err := p.Y.B0.A0.SetBytesCanonical(buf[fp.Bytes*7 : fp.Bytes*8]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1118,10 +1043,18 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } var YSquared, Y fptower.E4 @@ -1197,7 +1130,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1206,7 +1139,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1216,14 +1149,22 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0 : fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1 : fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0 : fp.Bytes*1]); err != nil { + return false, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1 : fp.Bytes*2]); err != nil { + return false, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return false, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return false, err + } // store mData in p.Y.B0.A0[0] p.Y.B0.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bls24-317/multiexp.go b/ecc/bls24-317/multiexp.go index 5a1b6797d..037aa6538 100644 --- a/ecc/bls24-317/multiexp.go +++ b/ecc/bls24-317/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,177 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +282,445 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bls24-317/multiexp_affine.go b/ecc/bls24-317/multiexp_affine.go new file mode 100644 index 000000000..803835d81 --- /dev/null +++ b/ecc/bls24-317/multiexp_affine.go @@ -0,0 +1,686 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls24317 + +import ( + "github.com/consensys/gnark-crypto/ecc/bls24-317/fp" + "github.com/consensys/gnark-crypto/ecc/bls24-317/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E4 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E4 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E4 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E4 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E4 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E4 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E4 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E4 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bls24-317/multiexp_jacobian.go b/ecc/bls24-317/multiexp_jacobian.go new file mode 100644 index 000000000..15fbf46f0 --- /dev/null +++ b/ecc/bls24-317/multiexp_jacobian.go @@ -0,0 +1,171 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bls24317 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bls24-317/multiexp_test.go b/ecc/bls24-317/multiexp_test.go index d5674abae..326cdaa45 100644 --- a/ecc/bls24-317/multiexp_test.go +++ b/ecc/bls24-317/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bls24-317/twistededwards/eddsa/doc.go b/ecc/bls24-317/twistededwards/eddsa/doc.go index 6aa59bf10..a7ab6fcd7 100644 --- a/ecc/bls24-317/twistededwards/eddsa/doc.go +++ b/ecc/bls24-317/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bls24-317's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go b/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go index 776d3894a..34464874c 100644 --- a/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bls24-317/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bls24-317/twistededwards/eddsa/marshal.go b/ecc/bls24-317/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bls24-317/twistededwards/eddsa/marshal.go +++ b/ecc/bls24-317/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bls24-317/twistededwards/point.go b/ecc/bls24-317/twistededwards/point.go index c2e7d2c58..0193e627b 100644 --- a/ecc/bls24-317/twistededwards/point.go +++ b/ecc/bls24-317/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bn254/bn254.go b/ecc/bn254/bn254.go index e52e97be4..976384317 100644 --- a/ecc/bn254/bn254.go +++ b/ecc/bn254/bn254.go @@ -16,23 +16,28 @@ // Ethereum pre-compiles as altbn128. // // bn254: A Barreto--Naerig curve with -// seed x₀=4965661367192848881 -// 𝔽r: r=21888242871839275222246405745257275088548364400416034343698204186575808495617 (36x₀⁴+36x₀³+18x₀²+6x₀+1) -// 𝔽p: p=21888242871839275222246405745257275088696311157297823662689037894645226208583 (36x₀⁴+36x₀³+24x₀²+6x₀+1) -// (E/𝔽p): Y²=X³+3 -// (Eₜ/𝔽p²): Y² = X³+3/(u+9) (D-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// +// seed x₀=4965661367192848881 +// 𝔽r: r=21888242871839275222246405745257275088548364400416034343698204186575808495617 (36x₀⁴+36x₀³+18x₀²+6x₀+1) +// 𝔽p: p=21888242871839275222246405745257275088696311157297823662689037894645226208583 (36x₀⁴+36x₀³+24x₀²+6x₀+1) +// (E/𝔽p): Y²=X³+3 +// (Eₜ/𝔽p²): Y² = X³+3/(u+9) (D-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p²) +// // Extension fields tower: -// 𝔽p²[u] = 𝔽p/u²+1 -// 𝔽p⁶[v] = 𝔽p²/v³-9-u -// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// +// 𝔽p²[u] = 𝔽p/u²+1 +// 𝔽p⁶[v] = 𝔽p²/v³-9-u +// 𝔽p¹²[w] = 𝔽p⁶/w²-v +// // optimal Ate loop size: -// 6x₀+2 +// +// 6x₀+2 // // Security: estimated 103-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 254 bits and p¹² is 3044 bits) // -// Warning +// # Warning // // This code has been partially audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bn254 @@ -94,9 +99,6 @@ var endo struct { // seed x₀ of the curve var xGen big.Int -// trace - 1 = 6x₀² -var fixedCoeff big.Int - func init() { bCurveCoeff.SetUint64(3) @@ -142,9 +144,6 @@ func init() { xGen.SetString("4965661367192848881", 10) - // 6x₀² - fixedCoeff.SetString("147946756881789318990833708069417712966", 10) - } // Generators return the generators of the r-torsion group, resp. in ker(pi-id), ker(Tr) diff --git a/ecc/bn254/fp/doc.go b/ecc/bn254/fp/doc.go index ef6facc39..ac24d2e0d 100644 --- a/ecc/bn254/fp/doc.go +++ b/ecc/bn254/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 +// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index c60263d4a..0ca5dc476 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 +// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 254 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 254 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 -// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 +// q[base10] = 21888242871839275222246405745257275088696311157297823662689037894645226208583 +// q[base16] = 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 9786893198990664585 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 11389680472494603940, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 493319470278259999, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1192,7 +1335,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1207,17 +1350,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1316,7 +1470,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bn254/fp/element_mul_adx_amd64.s b/ecc/bn254/fp/element_mul_adx_amd64.s deleted file mode 100644 index cb49047bf..000000000 --- a/ecc/bn254/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bn254/fp/element_mul_amd64.s b/ecc/bn254/fp/element_mul_amd64.s index dd7022a95..e58b31681 100644 --- a/ecc/bn254/fp/element_mul_amd64.s +++ b/ecc/bn254/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fp/element_ops_amd64.go b/ecc/bn254/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bn254/fp/element_ops_amd64.go +++ b/ecc/bn254/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bn254/fp/element_ops_amd64.s b/ecc/bn254/fp/element_ops_amd64.s index e726ce410..48f34db8f 100644 --- a/ecc/bn254/fp/element_ops_amd64.s +++ b/ecc/bn254/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bn254/fp/element_ops_noasm.go b/ecc/bn254/fp/element_ops_noasm.go deleted file mode 100644 index c3ed795a9..000000000 --- a/ecc/bn254/fp/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 529957932336199972, - 13952065197595570812, - 769406925088786211, - 2691790815622165739, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bn254/fp/element_ops_purego.go b/ecc/bn254/fp/element_ops_purego.go new file mode 100644 index 000000000..93aca54dd --- /dev/null +++ b/ecc/bn254/fp/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 529957932336199972, + 13952065197595570812, + 769406925088786211, + 2691790815622165739, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 3e0ca8b50..feb4a5bf1 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bn254/fr/doc.go b/ecc/bn254/fr/doc.go index d040bb1a7..35388f880 100644 --- a/ecc/bn254/fr/doc.go +++ b/ecc/bn254/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [4]uint64 // -// Usage +// type Element [4]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index 21cd36352..00b0499c7 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 4 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [4]uint64 const ( - Limbs = 4 // number of 64 bits words needed to represent a Element - Bits = 254 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 254 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element ) // Field modulus q @@ -68,8 +68,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 -// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 +// q[base10] = 21888242871839275222246405745257275088548364400416034343698204186575808495617 +// q[base16] = 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -78,12 +78,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 14042775128853446655 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", 16) } @@ -91,8 +85,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -103,7 +98,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -133,14 +128,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -250,15 +246,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -270,15 +264,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[3] > _x[3] { return 1 } else if _z[3] < _x[3] { @@ -309,8 +300,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 11669102379873075201, 0) @@ -401,67 +391,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -475,7 +407,7 @@ func (z *Element) Add(x, y *Element) *Element { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -495,7 +427,7 @@ func (z *Element) Double(x *Element) *Element { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -548,65 +480,147 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [4]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - t[3], t[2] = madd3(m, q3, c[0], c[2], c[1]) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - z[3], z[2] = madd3(m, q3, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -614,7 +628,6 @@ func _mulGeneric(z, x, y *Element) { z[2], b = bits.Sub64(z[2], q2, b) z[3], _ = bits.Sub64(z[3], q3, b) } - } func _fromMontGeneric(z *Element) { @@ -658,7 +671,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -670,7 +683,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -734,6 +747,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -748,8 +790,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -775,23 +817,29 @@ var rSquare = Element{ 150537098327114917, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -810,47 +858,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[24:32], z[0]) - binary.BigEndian.PutUint64(b[16:24], z[1]) - binary.BigEndian.PutUint64(b[8:16], z[2]) - binary.BigEndian.PutUint64(b[0:8], z[3]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[24:32], _z[0]) - binary.BigEndian.PutUint64(res[16:24], _z[1]) - binary.BigEndian.PutUint64(res[8:16], _z[2]) - binary.BigEndian.PutUint64(res[0:8], _z[3]) +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -863,19 +913,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -893,17 +968,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -925,20 +999,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -947,7 +1021,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -956,7 +1030,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -996,7 +1070,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1005,10 +1079,79 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1041,7 +1184,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1054,7 +1197,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(28) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1244,7 +1387,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1259,17 +1402,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1368,7 +1522,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[3], z[2] = madd2(m, q3, t[i+3], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bn254/fr/element_mul_adx_amd64.s b/ecc/bn254/fr/element_mul_adx_amd64.s deleted file mode 100644 index b2f972ea4..000000000 --- a/ecc/bn254/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,465 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x43e1f593f0000001 -DATA q<>+8(SB)/8, $0x2833e84879b97091 -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xc2e1f593efffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), SI - - // x[0] -> DI - // x[1] -> R8 - // x[2] -> R9 - // x[3] -> R10 - MOVQ 0(SI), DI - MOVQ 8(SI), R8 - MOVQ 16(SI), R9 - MOVQ 24(SI), R10 - MOVQ y+16(FP), R11 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // clear the flags - XORQ AX, AX - MOVQ 0(R11), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ DI, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R8, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R9, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 8(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 16(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // clear the flags - XORQ AX, AX - MOVQ 24(R11), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ DI, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R8, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R9, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ R10, AX, BP - ADOXQ AX, BX - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R12 - ADCXQ R14, AX - MOVQ R12, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // t[3] = C + A - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ BP, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,R12,R11,DI) - REDUCE(R14,R13,CX,BX,SI,R12,R11,DI) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - MOVQ $0, AX - ADCXQ AX, BX - ADOXQ AX, BX - - // reduce element(R14,R13,CX,BX) using temp registers (SI,DI,R8,R9) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - RET diff --git a/ecc/bn254/fr/element_mul_amd64.s b/ecc/bn254/fr/element_mul_amd64.s index 9452b5403..b51bc6998 100644 --- a/ecc/bn254/fr/element_mul_amd64.s +++ b/ecc/bn254/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bn254/fr/element_ops_amd64.go b/ecc/bn254/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bn254/fr/element_ops_amd64.go +++ b/ecc/bn254/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bn254/fr/element_ops_amd64.s b/ecc/bn254/fr/element_ops_amd64.s index 77efe619f..b9d8a5bfa 100644 --- a/ecc/bn254/fr/element_ops_amd64.s +++ b/ecc/bn254/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bn254/fr/element_ops_noasm.go b/ecc/bn254/fr/element_ops_noasm.go deleted file mode 100644 index 13f66e720..000000000 --- a/ecc/bn254/fr/element_ops_noasm.go +++ /dev/null @@ -1,61 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 17868810749992763324, - 5924006745939515753, - 769406925088786241, - 2691790815622165739, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bn254/fr/element_ops_purego.go b/ecc/bn254/fr/element_ops_purego.go new file mode 100644 index 000000000..be0f08583 --- /dev/null +++ b/ecc/bn254/fr/element_ops_purego.go @@ -0,0 +1,443 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 17868810749992763324, + 5924006745939515753, + 769406925088786241, + 2691790815622165739, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3 uint64 + var u0, u1, u2, u3 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + c2, _ = bits.Add64(u3, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + + t2, c0 = bits.Add64(0, c1, c0) + u3, _ = bits.Add64(u3, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + c2, _ = bits.Add64(c2, 0, c0) + t2, c0 = bits.Add64(t3, t2, 0) + t3, _ = bits.Add64(u3, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 6b1d61f60..2170bd19c 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -638,7 +631,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -753,7 +746,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -766,13 +759,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -801,17 +794,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -862,7 +855,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -875,13 +868,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -910,17 +903,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -971,7 +964,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -984,7 +977,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -998,7 +991,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1038,11 +1031,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1055,7 +1048,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1107,7 +1100,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1120,14 +1113,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1156,18 +1149,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1218,7 +1211,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1231,13 +1224,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1266,17 +1259,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1321,7 +1314,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1342,14 +1335,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1393,7 +1386,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1414,14 +1407,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1465,7 +1458,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1486,14 +1479,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1537,7 +1530,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1558,14 +1551,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1609,7 +1602,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1630,14 +1623,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2025,7 +2018,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2158,17 +2151,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2247,7 +2240,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2309,7 +2302,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2351,7 +2344,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord2, inversionCorrectionFactorWord3, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2393,7 +2386,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2538,7 +2531,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2751,11 +2744,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2792,7 +2785,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2810,7 +2803,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2821,7 +2814,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bn254/fr/fri/fri.go b/ecc/bn254/fr/fri/fri.go index 73e9c247a..664f87bba 100644 --- a/ecc/bn254/fr/fri/fri.go +++ b/ecc/bn254/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bn254/fr/gkr/gkr.go b/ecc/bn254/fr/gkr/gkr.go new file mode 100644 index 000000000..4a75a8611 --- /dev/null +++ b/ecc/bn254/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bn254/fr/gkr/gkr_test.go b/ecc/bn254/fr/gkr/gkr_test.go new file mode 100644 index 000000000..557a8ed0d --- /dev/null +++ b/ecc/bn254/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bn254/fr/kzg/kzg.go b/ecc/bn254/fr/kzg/kzg.go index 0fd45d0a3..8de86fa92 100644 --- a/ecc/bn254/fr/kzg/kzg.go +++ b/ecc/bn254/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bn254.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bn254.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bn254/fr/mimc/decompose.go b/ecc/bn254/fr/mimc/decompose.go new file mode 100644 index 000000000..e61417b9b --- /dev/null +++ b/ecc/bn254/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bn254/fr/mimc/decompose_test.go b/ecc/bn254/fr/mimc/decompose_test.go new file mode 100644 index 000000000..3597da7a3 --- /dev/null +++ b/ecc/bn254/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bn254/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bn254/fr/mimc/mimc.go b/ecc/bn254/fr/mimc/mimc.go index 43cc82ec4..87a9776ee 100644 --- a/ecc/bn254/fr/mimc/mimc.go +++ b/ecc/bn254/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bn254/fr/mimc/mimc_test.go b/ecc/bn254/fr/mimc/mimc_test.go deleted file mode 100644 index cb8995c03..000000000 --- a/ecc/bn254/fr/mimc/mimc_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package mimc - -// import ( -// "testing" -// ) - -// func TestMimc(t *testing.T) { - -// // Expected result from ethereum -// var data [3]fr.Element -// data[0].SetString("10909369219534740878285360918369814291778422174980871969149168794639722256599") -// data[1].SetString("3811523387212735178398974960485340561880938762308498768570292593755555588442") -// data[2].SetString("21761276089180230617904476026690048826689721630933485969915548849196498965166") - -// h := NewMiMC() -// h.Write(data[0].Marshal()) -// h.Write(data[1].Marshal()) -// h.Write(data[2].Marshal()) - -// r := h.Sum(nil) - -// var b big.Int -// b.SetBytes(r) -// fmt.Printf("%s\n", b.String()) - -//------- - -// h := NewMiMC("mimc") -// var a [3]fr.Element -// a[0].SetRandom() -// a[1].SetRandom() -// a[2].SetRandom() -// fmt.Printf("%s\n", a[0].String()) -// fmt.Printf("%s\n", a[1].String()) -// fmt.Printf("%s\n", a[2].String()) -// fmt.Println("") -// h.Write(a[0].Marshal()) -// h.Write(a[1].Marshal()) -// h.Write(a[2].Marshal()) -// var a fr.Element -// a.SetUint64(2323) -// h.Write(a.Marshal()) -// r := h.Sum(nil) -// var br big.Int -// br.SetBytes(r) -// fmt.Printf("%s\n", br.String()) -//_h := h.(*digest) - -// //var h1, h2, h3 fr.Element -// var h1, h2 fr.Element -// h1.SetString("948723") -// h2.SetString("236878") -// // h3.SetString("283") -// _h.data = append(_h.data, h1.Marshal()...) -// _h.data = append(_h.data, h2.Marshal()...) -// // _h.data = append(_h.data, h3.Marshal()...) - -// _h.checksum() -// fmt.Printf("%s\n", _h.h.String()) -// } diff --git a/ecc/bn254/fr/pedersen/pedersen.go b/ecc/bn254/fr/pedersen/pedersen.go new file mode 100644 index 000000000..475644705 --- /dev/null +++ b/ecc/bn254/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bn254.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bn254.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bn254.G1Affine + basisExpSigma []bn254.G1Affine +} + +func randomOnG2() (bn254.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bn254.G2Affine{}, err + } + return bn254.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bn254.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bn254.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bn254.G1Affine, knowledgeProof bn254.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bn254.G1Affine, knowledgeProof bn254.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bn254.Pair([]bn254.G1Affine{commitment, knowledgeProof}, []bn254.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bn254/fr/pedersen/pedersen_test.go b/ecc/bn254/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..eac9cebe6 --- /dev/null +++ b/ecc/bn254/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bn254.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bn254.G1Affine{}, err + } + return bn254.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bn254.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bn254.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bn254/fr/plookup/vector.go b/ecc/bn254/fr/plookup/vector.go index e2a31166b..5ebbce5c0 100644 --- a/ecc/bn254/fr/plookup/vector.go +++ b/ecc/bn254/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bn254/fr/polynomial/multilin.go b/ecc/bn254/fr/polynomial/multilin.go index 2b2fac3a1..0276fca1d 100644 --- a/ecc/bn254/fr/polynomial/multilin.go +++ b/ecc/bn254/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bn254/fr/polynomial/polynomial_test.go b/ecc/bn254/fr/polynomial/polynomial_test.go index 2664c30e2..a97c4df62 100644 --- a/ecc/bn254/fr/polynomial/polynomial_test.go +++ b/ecc/bn254/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bn254/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bn254/fr/polynomial/pool.go b/ecc/bn254/fr/polynomial/pool.go index 21f2e5a87..29ca322fa 100644 --- a/ecc/bn254/fr/polynomial/pool.go +++ b/ecc/bn254/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bn254/fr/sumcheck/sumcheck.go b/ecc/bn254/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..c7de59c15 --- /dev/null +++ b/ecc/bn254/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bn254/fr/sumcheck/sumcheck_test.go b/ecc/bn254/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..d5d86b17b --- /dev/null +++ b/ecc/bn254/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bn254/fr/test_vector_utils/test_vector_utils.go b/ecc/bn254/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..f39f6ae47 --- /dev/null +++ b/ecc/bn254/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + "github.com/consensys/gnark-crypto/ecc/bn254/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go b/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go new file mode 100644 index 000000000..9ecbab36b --- /dev/null +++ b/ecc/bn254/fr/test_vector_utils/test_vector_utils_test.go @@ -0,0 +1,44 @@ +package test_vector_utils + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fr" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "strconv" + "testing" +) + +func TestTranscript(t *testing.T) { + + mp, err := CreateElementMap(map[string]interface{}{ + strconv.Itoa('0'): 2, + "3,2": 5, + }) + assert.NoError(t, err) + + hsh := MapHash{Map: &mp} + transcript := fiatshamir.NewTranscript(&hsh, "0", "1") + bytes := ToElement(3).Bytes() + err = transcript.Bind("0", bytes[:]) + assert.NoError(t, err) + var cBytes []byte + cBytes, err = transcript.ComputeChallenge("0") + assert.NoError(t, err) + var res fr.Element + res.SetBytes(cBytes) + assert.True(t, ToElement(5).Equal(&res)) +} + +func TestCounterTranscriptInequality(t *testing.T) { + const challengeName = "fC.0" + t1 := fiatshamir.NewTranscript(test_vector_utils.NewMessageCounter(1, 1), challengeName) + t2 := fiatshamir.NewTranscript(test_vector_utils.NewMessageCounter(0, 1), challengeName) + var c1, c2 []byte + var err error + c1, err = t1.ComputeChallenge(challengeName) + assert.NoError(t, err) + c2, err = t2.ComputeChallenge(challengeName) + assert.NoError(t, err) + assert.NotEqual(t, c1, c2) +} diff --git a/ecc/bn254/g1.go b/ecc/bn254/g1.go index 872d8218c..62d1937ae 100644 --- a/ecc/bn254/g1.go +++ b/ecc/bn254/g1.go @@ -17,13 +17,12 @@ package bn254 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fp" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -463,8 +469,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -555,15 +561,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -575,11 +581,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -604,6 +606,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -846,95 +850,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -948,3 +929,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bn254/g1_test.go b/ecc/bn254/g1_test.go index 8ee025d78..ec7ce9eed 100644 --- a/ecc/bn254/g1_test.go +++ b/ecc/bn254/g1_test.go @@ -19,6 +19,7 @@ package bn254 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bn254/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -419,8 +420,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -433,7 +433,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -460,6 +460,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -472,8 +499,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bn254/g2.go b/ecc/bn254/g2.go index 0a8625caa..49615c541 100644 --- a/ecc/bn254/g2.go +++ b/ecc/bn254/g2.go @@ -17,13 +17,12 @@ package bn254 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/internal/fptower" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fptower.E2 } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fptower.E2 } @@ -55,6 +54,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -370,15 +376,22 @@ func (p *G2Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. -// [r]P == 0 <==> Frob(P) == [6x²]P +// https://eprint.iacr.org/2022/348.pdf, sec. 3 and 5.1 +// [r]P == 0 <==> [x₀+1]P + ψ([x₀]P) + ψ²([x₀]P) = ψ³([2x₀]P) func (p *G2Jac) IsInSubGroup() bool { - var a, res G2Jac - a.psi(p) - res.ScalarMultiplication(p, &fixedCoeff). - SubAssign(&a) + var a, b, c, res G2Jac + a.ScalarMultiplication(p, &xGen) + b.psi(&a) + a.AddAssign(p) + res.psi(&b) + c.Set(&res). + AddAssign(&b). + AddAssign(&a) + res.psi(&res). + Double(&res). + SubAssign(&c) return res.IsOnCurve() && res.Z.IsZero() - } // mulWindowed computes a 2-bits windowed scalar multiplication @@ -472,8 +485,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -599,15 +612,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fptower.E2 + var A, B, U1, U2, S1, S2 fptower.E2 // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -619,11 +632,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fptower.E2 - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fptower.E2 P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -648,6 +657,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fptower.E2 @@ -871,93 +882,70 @@ func (p *g2Proj) FromAffine(Q *G2Affine) *g2Proj { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -970,3 +958,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fptower.E2 + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fptower.E2 + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bn254/g2_test.go b/ecc/bn254/g2_test.go index 17c09d95b..bc0b7fff4 100644 --- a/ecc/bn254/g2_test.go +++ b/ecc/bn254/g2_test.go @@ -19,6 +19,7 @@ package bn254 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bn254/internal/fptower" @@ -338,7 +339,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -463,8 +464,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -477,7 +477,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -504,6 +504,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -516,8 +543,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bn254/hash_to_g1.go b/ecc/bn254/hash_to_g1.go index 9df840a53..50d40cd0e 100644 --- a/ecc/bn254/hash_to_g1.go +++ b/ecc/bn254/hash_to_g1.go @@ -17,7 +17,6 @@ package bn254 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fp" ) @@ -98,35 +97,14 @@ func mapToCurve1(u *fp.Element) G1Affine { return G1Affine{x, y} } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -141,11 +119,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SVDW map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -158,9 +136,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SVDW map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bn254/hash_to_g1_test.go b/ecc/bn254/hash_to_g1_test.go index 439cbba03..da1b1312d 100644 --- a/ecc/bn254/hash_to_g1_test.go +++ b/ecc/bn254/hash_to_g1_test.go @@ -26,7 +26,7 @@ import ( func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -34,7 +34,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -175,7 +175,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bn254/hash_to_g2.go b/ecc/bn254/hash_to_g2.go index 439540504..7cf07eeac 100644 --- a/ecc/bn254/hash_to_g2.go +++ b/ecc/bn254/hash_to_g2.go @@ -119,8 +119,7 @@ func mapToCurve2(u *fptower.E2) G2Affine { // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fptower.E2) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() sign := uint64(0) // 1. sign = 0 zero := uint64(1) // 2. zero = 1 @@ -152,11 +151,11 @@ func MapToG2(u fptower.E2) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SVDW map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 2) + u, err := fp.Hash(msg, dst, 2) if err != nil { return res, err } @@ -173,9 +172,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SVDW map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*2) + u, err := fp.Hash(msg, dst, 2*2) if err != nil { return G2Affine{}, err } diff --git a/ecc/bn254/hash_to_g2_test.go b/ecc/bn254/hash_to_g2_test.go index 8971fe83d..33e231c67 100644 --- a/ecc/bn254/hash_to_g2_test.go +++ b/ecc/bn254/hash_to_g2_test.go @@ -28,7 +28,7 @@ import ( func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 2) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 2) if err != nil { t.Error(err) } @@ -36,7 +36,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*2) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*2) if err != nil { t.Error(err) } @@ -177,7 +177,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fptower.E2, s string) { ssplit := strings.Split(s, ",") if len(ssplit) != 2 { diff --git a/ecc/bn254/internal/fptower/e12.go b/ecc/bn254/internal/fptower/e12.go index 087e81e17..8fa8345ec 100644 --- a/ecc/bn254/internal/fptower/e12.go +++ b/ecc/bn254/internal/fptower/e12.go @@ -17,7 +17,6 @@ package fptower import ( - "encoding/binary" "errors" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fp" @@ -68,20 +67,6 @@ func (z *E12) SetOne() *E12 { return z } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - return z -} - // Add set z=x+y in E12 and return z func (z *E12) Add(x, y *E12) *E12 { z.C0.Add(&x.C0, &y.C0) @@ -119,6 +104,10 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() } +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + // Mul set z=x*y in E12 and return z func (z *E12) Mul(x, y *E12) *E12 { var a, b, c E6 @@ -226,9 +215,12 @@ func (z *E12) CyclotomicSquareCompressed(x *E12) *E12 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -289,9 +281,12 @@ func (z *E12) DecompressKarabina(x *E12) *E12 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -602,8 +597,8 @@ func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { @@ -652,69 +647,20 @@ func (z *E12) Unmarshal(buf []byte) error { // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) Bytes() (r [SizeOfGT]byte) { - _z := *z - _z.FromMont() - binary.BigEndian.PutUint64(r[376:384], _z.C0.B0.A0[0]) - binary.BigEndian.PutUint64(r[368:376], _z.C0.B0.A0[1]) - binary.BigEndian.PutUint64(r[360:368], _z.C0.B0.A0[2]) - binary.BigEndian.PutUint64(r[352:360], _z.C0.B0.A0[3]) - - binary.BigEndian.PutUint64(r[344:352], _z.C0.B0.A1[0]) - binary.BigEndian.PutUint64(r[336:344], _z.C0.B0.A1[1]) - binary.BigEndian.PutUint64(r[328:336], _z.C0.B0.A1[2]) - binary.BigEndian.PutUint64(r[320:328], _z.C0.B0.A1[3]) - - binary.BigEndian.PutUint64(r[312:320], _z.C0.B1.A0[0]) - binary.BigEndian.PutUint64(r[304:312], _z.C0.B1.A0[1]) - binary.BigEndian.PutUint64(r[296:304], _z.C0.B1.A0[2]) - binary.BigEndian.PutUint64(r[288:296], _z.C0.B1.A0[3]) - - binary.BigEndian.PutUint64(r[280:288], _z.C0.B1.A1[0]) - binary.BigEndian.PutUint64(r[272:280], _z.C0.B1.A1[1]) - binary.BigEndian.PutUint64(r[264:272], _z.C0.B1.A1[2]) - binary.BigEndian.PutUint64(r[256:264], _z.C0.B1.A1[3]) - - binary.BigEndian.PutUint64(r[248:256], _z.C0.B2.A0[0]) - binary.BigEndian.PutUint64(r[240:248], _z.C0.B2.A0[1]) - binary.BigEndian.PutUint64(r[232:240], _z.C0.B2.A0[2]) - binary.BigEndian.PutUint64(r[224:232], _z.C0.B2.A0[3]) - - binary.BigEndian.PutUint64(r[216:224], _z.C0.B2.A1[0]) - binary.BigEndian.PutUint64(r[208:216], _z.C0.B2.A1[1]) - binary.BigEndian.PutUint64(r[200:208], _z.C0.B2.A1[2]) - binary.BigEndian.PutUint64(r[192:200], _z.C0.B2.A1[3]) - - binary.BigEndian.PutUint64(r[184:192], _z.C1.B0.A0[0]) - binary.BigEndian.PutUint64(r[176:184], _z.C1.B0.A0[1]) - binary.BigEndian.PutUint64(r[168:176], _z.C1.B0.A0[2]) - binary.BigEndian.PutUint64(r[160:168], _z.C1.B0.A0[3]) - - binary.BigEndian.PutUint64(r[152:160], _z.C1.B0.A1[0]) - binary.BigEndian.PutUint64(r[144:152], _z.C1.B0.A1[1]) - binary.BigEndian.PutUint64(r[136:144], _z.C1.B0.A1[2]) - binary.BigEndian.PutUint64(r[128:136], _z.C1.B0.A1[3]) - - binary.BigEndian.PutUint64(r[120:128], _z.C1.B1.A0[0]) - binary.BigEndian.PutUint64(r[112:120], _z.C1.B1.A0[1]) - binary.BigEndian.PutUint64(r[104:112], _z.C1.B1.A0[2]) - binary.BigEndian.PutUint64(r[96:104], _z.C1.B1.A0[3]) - - binary.BigEndian.PutUint64(r[88:96], _z.C1.B1.A1[0]) - binary.BigEndian.PutUint64(r[80:88], _z.C1.B1.A1[1]) - binary.BigEndian.PutUint64(r[72:80], _z.C1.B1.A1[2]) - binary.BigEndian.PutUint64(r[64:72], _z.C1.B1.A1[3]) - - binary.BigEndian.PutUint64(r[56:64], _z.C1.B2.A0[0]) - binary.BigEndian.PutUint64(r[48:56], _z.C1.B2.A0[1]) - binary.BigEndian.PutUint64(r[40:48], _z.C1.B2.A0[2]) - binary.BigEndian.PutUint64(r[32:40], _z.C1.B2.A0[3]) - - binary.BigEndian.PutUint64(r[24:32], _z.C1.B2.A1[0]) - binary.BigEndian.PutUint64(r[16:24], _z.C1.B2.A1[1]) - binary.BigEndian.PutUint64(r[8:16], _z.C1.B2.A1[2]) - binary.BigEndian.PutUint64(r[0:8], _z.C1.B2.A1[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[352:352+fp.Bytes]), z.C0.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[320:320+fp.Bytes]), z.C0.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[288:288+fp.Bytes]), z.C0.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[256:256+fp.Bytes]), z.C0.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[224:224+fp.Bytes]), z.C0.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[192:192+fp.Bytes]), z.C0.B2.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[160:160+fp.Bytes]), z.C1.B0.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[128:128+fp.Bytes]), z.C1.B0.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[96:96+fp.Bytes]), z.C1.B1.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[64:64+fp.Bytes]), z.C1.B1.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[32:32+fp.Bytes]), z.C1.B2.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(r[0:0+fp.Bytes]), z.C1.B2.A1) return } @@ -722,34 +668,47 @@ func (z *E12) Bytes() (r [SizeOfGT]byte) { // SetBytes interprets e as the bytes of a big-endian GT // sets z to that value (in Montgomery form), and returns z. // size(e) == 32 * 12 -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) SetBytes(e []byte) error { if len(e) != SizeOfGT { return errors.New("invalid buffer size") } - z.C0.B0.A0.SetBytes(e[352 : 352+fp.Bytes]) - - z.C0.B0.A1.SetBytes(e[320 : 320+fp.Bytes]) - - z.C0.B1.A0.SetBytes(e[288 : 288+fp.Bytes]) - - z.C0.B1.A1.SetBytes(e[256 : 256+fp.Bytes]) - - z.C0.B2.A0.SetBytes(e[224 : 224+fp.Bytes]) - - z.C0.B2.A1.SetBytes(e[192 : 192+fp.Bytes]) - - z.C1.B0.A0.SetBytes(e[160 : 160+fp.Bytes]) - - z.C1.B0.A1.SetBytes(e[128 : 128+fp.Bytes]) - - z.C1.B1.A0.SetBytes(e[96 : 96+fp.Bytes]) - - z.C1.B1.A1.SetBytes(e[64 : 64+fp.Bytes]) - - z.C1.B2.A0.SetBytes(e[32 : 32+fp.Bytes]) - - z.C1.B2.A1.SetBytes(e[0 : 0+fp.Bytes]) + if err := z.C0.B0.A0.SetBytesCanonical(e[352 : 352+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B0.A1.SetBytesCanonical(e[320 : 320+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A0.SetBytesCanonical(e[288 : 288+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B1.A1.SetBytesCanonical(e[256 : 256+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A0.SetBytesCanonical(e[224 : 224+fp.Bytes]); err != nil { + return err + } + if err := z.C0.B2.A1.SetBytesCanonical(e[192 : 192+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A0.SetBytesCanonical(e[160 : 160+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B0.A1.SetBytesCanonical(e[128 : 128+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A0.SetBytesCanonical(e[96 : 96+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B1.A1.SetBytesCanonical(e[64 : 64+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A0.SetBytesCanonical(e[32 : 32+fp.Bytes]); err != nil { + return err + } + if err := z.C1.B2.A1.SetBytesCanonical(e[0 : 0+fp.Bytes]); err != nil { + return err + } return nil } diff --git a/ecc/bn254/internal/fptower/e2.go b/ecc/bn254/internal/fptower/e2.go index 12be3b400..65dda3992 100644 --- a/ecc/bn254/internal/fptower/e2.go +++ b/ecc/bn254/internal/fptower/e2.go @@ -31,12 +31,20 @@ func (z *E2) Equal(x *E2) bool { return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) } +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2{} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *E2) Cmp(x *E2) int { if a1 := z.A1.Cmp(&x.A1); a1 != 0 { return a1 @@ -98,6 +106,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -127,20 +139,6 @@ func (z *E2) String() string { return z.A0.String() + "+" + z.A1.String() + "*u" } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/ecc/bn254/internal/fptower/e2_adx_amd64.s b/ecc/bn254/internal/fptower/e2_adx_amd64.s deleted file mode 100644 index aabff96ad..000000000 --- a/ecc/bn254/internal/fptower/e2_adx_amd64.s +++ /dev/null @@ -1,732 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x3c208c16d87cfd47 -DATA q<>+8(SB)/8, $0x97816a916871ca8d -DATA q<>+16(SB)/8, $0xb85045b68181585d -DATA q<>+24(SB)/8, $0x30644e72e131a029 -GLOBL q<>(SB), (RODATA+NOPTR), $32 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x87d20782e4866389 -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, rb0, rb1, rb2, rb3) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - -// this code is generated and identical to fp.Mul(...) -#define MUL() \ - XORQ AX, AX; \ - MOVQ SI, DX; \ - MULXQ R14, R10, R11; \ - MULXQ R15, AX, R12; \ - ADOXQ AX, R11; \ - MULXQ CX, AX, R13; \ - ADOXQ AX, R12; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - MOVQ DI, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - MOVQ R8, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - XORQ AX, AX; \ - MOVQ R9, DX; \ - MULXQ R14, AX, BP; \ - ADOXQ AX, R10; \ - ADCXQ BP, R11; \ - MULXQ R15, AX, BP; \ - ADOXQ AX, R11; \ - ADCXQ BP, R12; \ - MULXQ CX, AX, BP; \ - ADOXQ AX, R12; \ - ADCXQ BP, R13; \ - MULXQ BX, AX, BP; \ - ADOXQ AX, R13; \ - MOVQ $0, AX; \ - ADCXQ AX, BP; \ - ADOXQ AX, BP; \ - PUSHQ BP; \ - MOVQ qInv0<>(SB), DX; \ - IMULQ R10, DX; \ - XORQ AX, AX; \ - MULXQ q<>+0(SB), AX, BP; \ - ADCXQ R10, AX; \ - MOVQ BP, R10; \ - POPQ BP; \ - ADCXQ R11, R10; \ - MULXQ q<>+8(SB), AX, R11; \ - ADOXQ AX, R10; \ - ADCXQ R12, R11; \ - MULXQ q<>+16(SB), AX, R12; \ - ADOXQ AX, R11; \ - ADCXQ R13, R12; \ - MULXQ q<>+24(SB), AX, R13; \ - ADOXQ AX, R12; \ - MOVQ $0, AX; \ - ADCXQ AX, R13; \ - ADOXQ BP, R13; \ - -TEXT ·addE2(SB), NOSPLIT, $0-24 - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ y+16(FP), DX - ADDQ 0(DX), BX - ADCQ 8(DX), SI - ADCQ 16(DX), DI - ADCQ 24(DX), R8 - - // reduce element(BX,SI,DI,R8) using temp registers (R9,R10,R11,R12) - REDUCE(BX,SI,DI,R8,R9,R10,R11,R12) - - MOVQ res+0(FP), CX - MOVQ BX, 0(CX) - MOVQ SI, 8(CX) - MOVQ DI, 16(CX) - MOVQ R8, 24(CX) - MOVQ 32(AX), BX - MOVQ 40(AX), SI - MOVQ 48(AX), DI - MOVQ 56(AX), R8 - ADDQ 32(DX), BX - ADCQ 40(DX), SI - ADCQ 48(DX), DI - ADCQ 56(DX), R8 - - // reduce element(BX,SI,DI,R8) using temp registers (R13,R14,R15,R9) - REDUCE(BX,SI,DI,R8,R13,R14,R15,R9) - - MOVQ BX, 32(CX) - MOVQ SI, 40(CX) - MOVQ DI, 48(CX) - MOVQ R8, 56(CX) - RET - -TEXT ·doubleE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), CX - MOVQ 8(AX), BX - MOVQ 16(AX), SI - MOVQ 24(AX), DI - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(CX,BX,SI,DI) using temp registers (R8,R9,R10,R11) - REDUCE(CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ CX, 0(DX) - MOVQ BX, 8(DX) - MOVQ SI, 16(DX) - MOVQ DI, 24(DX) - MOVQ 32(AX), CX - MOVQ 40(AX), BX - MOVQ 48(AX), SI - MOVQ 56(AX), DI - ADDQ CX, CX - ADCQ BX, BX - ADCQ SI, SI - ADCQ DI, DI - - // reduce element(CX,BX,SI,DI) using temp registers (R12,R13,R14,R15) - REDUCE(CX,BX,SI,DI,R12,R13,R14,R15) - - MOVQ CX, 32(DX) - MOVQ BX, 40(DX) - MOVQ SI, 48(DX) - MOVQ DI, 56(DX) - RET - -TEXT ·subE2(SB), NOSPLIT, $0-24 - XORQ DI, DI - MOVQ x+8(FP), SI - MOVQ 0(SI), AX - MOVQ 8(SI), DX - MOVQ 16(SI), CX - MOVQ 24(SI), BX - MOVQ y+16(FP), SI - SUBQ 0(SI), AX - SBBQ 8(SI), DX - SBBQ 16(SI), CX - SBBQ 24(SI), BX - MOVQ x+8(FP), SI - MOVQ $0x3c208c16d87cfd47, R8 - MOVQ $0x97816a916871ca8d, R9 - MOVQ $0xb85045b68181585d, R10 - MOVQ $0x30644e72e131a029, R11 - CMOVQCC DI, R8 - CMOVQCC DI, R9 - CMOVQCC DI, R10 - CMOVQCC DI, R11 - ADDQ R8, AX - ADCQ R9, DX - ADCQ R10, CX - ADCQ R11, BX - MOVQ res+0(FP), R12 - MOVQ AX, 0(R12) - MOVQ DX, 8(R12) - MOVQ CX, 16(R12) - MOVQ BX, 24(R12) - MOVQ 32(SI), AX - MOVQ 40(SI), DX - MOVQ 48(SI), CX - MOVQ 56(SI), BX - MOVQ y+16(FP), SI - SUBQ 32(SI), AX - SBBQ 40(SI), DX - SBBQ 48(SI), CX - SBBQ 56(SI), BX - MOVQ $0x3c208c16d87cfd47, R13 - MOVQ $0x97816a916871ca8d, R14 - MOVQ $0xb85045b68181585d, R15 - MOVQ $0x30644e72e131a029, R8 - CMOVQCC DI, R13 - CMOVQCC DI, R14 - CMOVQCC DI, R15 - CMOVQCC DI, R8 - ADDQ R13, AX - ADCQ R14, DX - ADCQ R15, CX - ADCQ R8, BX - MOVQ res+0(FP), SI - MOVQ AX, 32(SI) - MOVQ DX, 40(SI) - MOVQ CX, 48(SI) - MOVQ BX, 56(SI) - RET - -TEXT ·negE2(SB), NOSPLIT, $0-16 - MOVQ res+0(FP), DX - MOVQ x+8(FP), AX - MOVQ 0(AX), BX - MOVQ 8(AX), SI - MOVQ 16(AX), DI - MOVQ 24(AX), R8 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - TESTQ AX, AX - JNE l1 - MOVQ AX, 0(DX) - MOVQ AX, 8(DX) - MOVQ AX, 16(DX) - MOVQ AX, 24(DX) - JMP l3 - -l1: - MOVQ $0x3c208c16d87cfd47, CX - SUBQ BX, CX - MOVQ CX, 0(DX) - MOVQ $0x97816a916871ca8d, CX - SBBQ SI, CX - MOVQ CX, 8(DX) - MOVQ $0xb85045b68181585d, CX - SBBQ DI, CX - MOVQ CX, 16(DX) - MOVQ $0x30644e72e131a029, CX - SBBQ R8, CX - MOVQ CX, 24(DX) - -l3: - MOVQ x+8(FP), AX - MOVQ 32(AX), BX - MOVQ 40(AX), SI - MOVQ 48(AX), DI - MOVQ 56(AX), R8 - MOVQ BX, AX - ORQ SI, AX - ORQ DI, AX - ORQ R8, AX - TESTQ AX, AX - JNE l2 - MOVQ AX, 32(DX) - MOVQ AX, 40(DX) - MOVQ AX, 48(DX) - MOVQ AX, 56(DX) - RET - -l2: - MOVQ $0x3c208c16d87cfd47, CX - SUBQ BX, CX - MOVQ CX, 32(DX) - MOVQ $0x97816a916871ca8d, CX - SBBQ SI, CX - MOVQ CX, 40(DX) - MOVQ $0xb85045b68181585d, CX - SBBQ DI, CX - MOVQ CX, 48(DX) - MOVQ $0x30644e72e131a029, CX - SBBQ R8, CX - MOVQ CX, 56(DX) - RET - -TEXT ·mulNonResE2(SB), NOSPLIT, $0-16 - MOVQ x+8(FP), R10 - MOVQ 0(R10), AX - MOVQ 8(R10), DX - MOVQ 16(R10), CX - MOVQ 24(R10), BX - ADDQ AX, AX - ADCQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - - // reduce element(AX,DX,CX,BX) using temp registers (R11,R12,R13,R14) - REDUCE(AX,DX,CX,BX,R11,R12,R13,R14) - - ADDQ AX, AX - ADCQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - - // reduce element(AX,DX,CX,BX) using temp registers (R15,R11,R12,R13) - REDUCE(AX,DX,CX,BX,R15,R11,R12,R13) - - ADDQ AX, AX - ADCQ DX, DX - ADCQ CX, CX - ADCQ BX, BX - - // reduce element(AX,DX,CX,BX) using temp registers (R14,R15,R11,R12) - REDUCE(AX,DX,CX,BX,R14,R15,R11,R12) - - ADDQ 0(R10), AX - ADCQ 8(R10), DX - ADCQ 16(R10), CX - ADCQ 24(R10), BX - - // reduce element(AX,DX,CX,BX) using temp registers (R13,R14,R15,R11) - REDUCE(AX,DX,CX,BX,R13,R14,R15,R11) - - MOVQ 32(R10), SI - MOVQ 40(R10), DI - MOVQ 48(R10), R8 - MOVQ 56(R10), R9 - XORQ R12, R12 - SUBQ SI, AX - SBBQ DI, DX - SBBQ R8, CX - SBBQ R9, BX - MOVQ $0x3c208c16d87cfd47, R13 - MOVQ $0x97816a916871ca8d, R14 - MOVQ $0xb85045b68181585d, R15 - MOVQ $0x30644e72e131a029, R11 - CMOVQCC R12, R13 - CMOVQCC R12, R14 - CMOVQCC R12, R15 - CMOVQCC R12, R11 - ADDQ R13, AX - ADCQ R14, DX - ADCQ R15, CX - ADCQ R11, BX - ADDQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R13,R14,R15,R11) - REDUCE(SI,DI,R8,R9,R13,R14,R15,R11) - - ADDQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R12,R13,R14,R15) - REDUCE(SI,DI,R8,R9,R12,R13,R14,R15) - - ADDQ SI, SI - ADCQ DI, DI - ADCQ R8, R8 - ADCQ R9, R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R11,R12,R13,R14) - REDUCE(SI,DI,R8,R9,R11,R12,R13,R14) - - ADDQ 32(R10), SI - ADCQ 40(R10), DI - ADCQ 48(R10), R8 - ADCQ 56(R10), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R15,R11,R12,R13) - REDUCE(SI,DI,R8,R9,R15,R11,R12,R13) - - ADDQ 0(R10), SI - ADCQ 8(R10), DI - ADCQ 16(R10), R8 - ADCQ 24(R10), R9 - - // reduce element(SI,DI,R8,R9) using temp registers (R14,R15,R11,R12) - REDUCE(SI,DI,R8,R9,R14,R15,R11,R12) - - MOVQ res+0(FP), R10 - MOVQ AX, 0(R10) - MOVQ DX, 8(R10) - MOVQ CX, 16(R10) - MOVQ BX, 24(R10) - MOVQ SI, 32(R10) - MOVQ DI, 40(R10) - MOVQ R8, 48(R10) - MOVQ R9, 56(R10) - RET - -TEXT ·mulAdxE2(SB), $64-24 - NO_LOCAL_POINTERS - - // var a, b, c fp.Element - // a.Add(&x.A0, &x.A1) - // b.Add(&y.A0, &y.A1) - // a.Mul(&a, &b) - // b.Mul(&x.A0, &y.A0) - // c.Mul(&x.A1, &y.A1) - // z.A1.Sub(&a, &b).Sub(&z.A1, &c) - // z.A0.Sub(&b, &c) - - MOVQ x+8(FP), AX - MOVQ y+16(FP), DX - MOVQ 32(AX), R14 - MOVQ 40(AX), R15 - MOVQ 48(AX), CX - MOVQ 56(AX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - - // mul (R14,R15,CX,BX) with (SI,DI,R8,R9) into (R10,R11,R12,R13) - MUL() - - // reduce element(R10,R11,R12,R13) using temp registers (SI,DI,R8,R9) - REDUCE(R10,R11,R12,R13,SI,DI,R8,R9) - - MOVQ R10, s4-40(SP) - MOVQ R11, s5-48(SP) - MOVQ R12, s6-56(SP) - MOVQ R13, s7-64(SP) - MOVQ x+8(FP), AX - MOVQ y+16(FP), DX - ADDQ 0(AX), R14 - ADCQ 8(AX), R15 - ADCQ 16(AX), CX - ADCQ 24(AX), BX - MOVQ 0(DX), SI - MOVQ 8(DX), DI - MOVQ 16(DX), R8 - MOVQ 24(DX), R9 - ADDQ 32(DX), SI - ADCQ 40(DX), DI - ADCQ 48(DX), R8 - ADCQ 56(DX), R9 - - // mul (R14,R15,CX,BX) with (SI,DI,R8,R9) into (R10,R11,R12,R13) - MUL() - - // reduce element(R10,R11,R12,R13) using temp registers (SI,DI,R8,R9) - REDUCE(R10,R11,R12,R13,SI,DI,R8,R9) - - MOVQ R10, s0-8(SP) - MOVQ R11, s1-16(SP) - MOVQ R12, s2-24(SP) - MOVQ R13, s3-32(SP) - MOVQ x+8(FP), AX - MOVQ y+16(FP), DX - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 0(DX), SI - MOVQ 8(DX), DI - MOVQ 16(DX), R8 - MOVQ 24(DX), R9 - - // mul (R14,R15,CX,BX) with (SI,DI,R8,R9) into (R10,R11,R12,R13) - MUL() - - // reduce element(R10,R11,R12,R13) using temp registers (SI,DI,R8,R9) - REDUCE(R10,R11,R12,R13,SI,DI,R8,R9) - - XORQ DX, DX - MOVQ s0-8(SP), R14 - MOVQ s1-16(SP), R15 - MOVQ s2-24(SP), CX - MOVQ s3-32(SP), BX - SUBQ R10, R14 - SBBQ R11, R15 - SBBQ R12, CX - SBBQ R13, BX - MOVQ $0x3c208c16d87cfd47, SI - MOVQ $0x97816a916871ca8d, DI - MOVQ $0xb85045b68181585d, R8 - MOVQ $0x30644e72e131a029, R9 - CMOVQCC DX, SI - CMOVQCC DX, DI - CMOVQCC DX, R8 - CMOVQCC DX, R9 - ADDQ SI, R14 - ADCQ DI, R15 - ADCQ R8, CX - ADCQ R9, BX - SUBQ s4-40(SP), R14 - SBBQ s5-48(SP), R15 - SBBQ s6-56(SP), CX - SBBQ s7-64(SP), BX - MOVQ $0x3c208c16d87cfd47, SI - MOVQ $0x97816a916871ca8d, DI - MOVQ $0xb85045b68181585d, R8 - MOVQ $0x30644e72e131a029, R9 - CMOVQCC DX, SI - CMOVQCC DX, DI - CMOVQCC DX, R8 - CMOVQCC DX, R9 - ADDQ SI, R14 - ADCQ DI, R15 - ADCQ R8, CX - ADCQ R9, BX - MOVQ res+0(FP), AX - MOVQ R14, 32(AX) - MOVQ R15, 40(AX) - MOVQ CX, 48(AX) - MOVQ BX, 56(AX) - MOVQ s4-40(SP), SI - MOVQ s5-48(SP), DI - MOVQ s6-56(SP), R8 - MOVQ s7-64(SP), R9 - SUBQ SI, R10 - SBBQ DI, R11 - SBBQ R8, R12 - SBBQ R9, R13 - MOVQ $0x3c208c16d87cfd47, R14 - MOVQ $0x97816a916871ca8d, R15 - MOVQ $0xb85045b68181585d, CX - MOVQ $0x30644e72e131a029, BX - CMOVQCC DX, R14 - CMOVQCC DX, R15 - CMOVQCC DX, CX - CMOVQCC DX, BX - ADDQ R14, R10 - ADCQ R15, R11 - ADCQ CX, R12 - ADCQ BX, R13 - MOVQ R10, 0(AX) - MOVQ R11, 8(AX) - MOVQ R12, 16(AX) - MOVQ R13, 24(AX) - RET - -TEXT ·squareAdxE2(SB), NOSPLIT, $0-16 - NO_LOCAL_POINTERS - - // z.A0 = (x.A0 + x.A1) * (x.A0 - x.A1) - // z.A1 = 2 * x.A0 * x.A1 - - // 2 * x.A0 * x.A1 - MOVQ x+8(FP), AX - - // x.A0[0] -> SI - // x.A0[1] -> DI - // x.A0[2] -> R8 - // x.A0[3] -> R9 - MOVQ 0(AX), SI - MOVQ 8(AX), DI - MOVQ 16(AX), R8 - MOVQ 24(AX), R9 - - // 2 * x.A1[0] -> R14 - // 2 * x.A1[1] -> R15 - // 2 * x.A1[2] -> CX - // 2 * x.A1[3] -> BX - MOVQ 32(AX), R14 - MOVQ 40(AX), R15 - MOVQ 48(AX), CX - MOVQ 56(AX), BX - ADDQ R14, R14 - ADCQ R15, R15 - ADCQ CX, CX - ADCQ BX, BX - - // mul (R14,R15,CX,BX) with (SI,DI,R8,R9) into (R10,R11,R12,R13) - MUL() - - // reduce element(R10,R11,R12,R13) using temp registers (R14,R15,CX,BX) - REDUCE(R10,R11,R12,R13,R14,R15,CX,BX) - - MOVQ x+8(FP), AX - - // x.A1[0] -> R14 - // x.A1[1] -> R15 - // x.A1[2] -> CX - // x.A1[3] -> BX - MOVQ 32(AX), R14 - MOVQ 40(AX), R15 - MOVQ 48(AX), CX - MOVQ 56(AX), BX - MOVQ res+0(FP), DX - MOVQ R10, 32(DX) - MOVQ R11, 40(DX) - MOVQ R12, 48(DX) - MOVQ R13, 56(DX) - MOVQ R14, R10 - MOVQ R15, R11 - MOVQ CX, R12 - MOVQ BX, R13 - - // Add(&x.A0, &x.A1) - ADDQ SI, R14 - ADCQ DI, R15 - ADCQ R8, CX - ADCQ R9, BX - XORQ BP, BP - - // Sub(&x.A0, &x.A1) - SUBQ R10, SI - SBBQ R11, DI - SBBQ R12, R8 - SBBQ R13, R9 - MOVQ $0x3c208c16d87cfd47, R10 - MOVQ $0x97816a916871ca8d, R11 - MOVQ $0xb85045b68181585d, R12 - MOVQ $0x30644e72e131a029, R13 - CMOVQCC BP, R10 - CMOVQCC BP, R11 - CMOVQCC BP, R12 - CMOVQCC BP, R13 - ADDQ R10, SI - ADCQ R11, DI - ADCQ R12, R8 - ADCQ R13, R9 - - // mul (R14,R15,CX,BX) with (SI,DI,R8,R9) into (R10,R11,R12,R13) - MUL() - - // reduce element(R10,R11,R12,R13) using temp registers (R14,R15,CX,BX) - REDUCE(R10,R11,R12,R13,R14,R15,CX,BX) - - MOVQ res+0(FP), AX - MOVQ R10, 0(AX) - MOVQ R11, 8(AX) - MOVQ R12, 16(AX) - MOVQ R13, 24(AX) - RET diff --git a/ecc/bn254/internal/fptower/e2_amd64.s b/ecc/bn254/internal/fptower/e2_amd64.s index d0c8e8a3d..43ffb7f16 100644 --- a/ecc/bn254/internal/fptower/e2_amd64.s +++ b/ecc/bn254/internal/fptower/e2_amd64.s @@ -1,5 +1,3 @@ -// +build !amd64_adx - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bn254/internal/fptower/e6.go b/ecc/bn254/internal/fptower/e6.go index 4da093f5f..8ae7216ec 100644 --- a/ecc/bn254/internal/fptower/e6.go +++ b/ecc/bn254/internal/fptower/e6.go @@ -63,25 +63,13 @@ func (z *E6) SetRandom() (*E6, error) { return z, nil } -// IsZero returns true if the two elements are equal, fasle otherwise +// IsZero returns true if the two elements are equal, false otherwise func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - z.B2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - z.B2.FromMont() - return z +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() } // Add adds two elements of E6 diff --git a/ecc/bn254/marshal.go b/ecc/bn254/marshal.go index 586f29686..4ae011cf1 100644 --- a/ecc/bn254/marshal.go +++ b/ecc/bn254/marshal.go @@ -94,7 +94,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -102,7 +102,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -120,7 +120,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -139,7 +141,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -215,7 +219,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -270,7 +278,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -626,11 +638,11 @@ func (p *G1Affine) Unmarshal(buf []byte) error { // // we use the 2 most significant bits instead // -// 00 -> uncompressed -// 10 -> compressed, use smallest lexicographically square root of Y^2 -// 11 -> compressed, use largest lexicographically square root of Y^2 -// 01 -> compressed infinity point -// the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) +// 00 -> uncompressed +// 10 -> compressed, use smallest lexicographically square root of Y^2 +// 11 -> compressed, use largest lexicographically square root of Y^2 +// 01 -> compressed infinity point +// the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { // check if p is infinity point @@ -639,9 +651,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -650,12 +659,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[24:32], tmp[0]) - binary.BigEndian.PutUint64(res[16:24], tmp[1]) - binary.BigEndian.PutUint64(res[8:16], tmp[2]) - binary.BigEndian.PutUint64(res[0:8], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -674,25 +678,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[56:64], tmp[0]) - binary.BigEndian.PutUint64(res[48:56], tmp[1]) - binary.BigEndian.PutUint64(res[40:48], tmp[2]) - binary.BigEndian.PutUint64(res[32:40], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[32:32+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[24:32], tmp[0]) - binary.BigEndian.PutUint64(res[16:24], tmp[1]) - binary.BigEndian.PutUint64(res[8:16], tmp[2]) - binary.BigEndian.PutUint64(res[0:8], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -738,8 +729,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -759,7 +754,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -833,7 +830,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -842,7 +839,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -851,12 +848,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -883,11 +882,11 @@ func (p *G2Affine) Unmarshal(buf []byte) error { // // we use the 2 most significant bits instead // -// 00 -> uncompressed -// 10 -> compressed, use smallest lexicographically square root of Y^2 -// 11 -> compressed, use largest lexicographically square root of Y^2 -// 01 -> compressed infinity point -// the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) +// 00 -> uncompressed +// 10 -> compressed, use smallest lexicographically square root of Y^2 +// 11 -> compressed, use largest lexicographically square root of Y^2 +// 01 -> compressed infinity point +// the "uncompressed infinity point" will just have 00 (uncompressed) followed by zeroes (infinity = 0,0 in affine coordinates) func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // check if p is infinity point @@ -896,9 +895,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -908,19 +904,8 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[56:64], tmp[0]) - binary.BigEndian.PutUint64(res[48:56], tmp[1]) - binary.BigEndian.PutUint64(res[40:48], tmp[2]) - binary.BigEndian.PutUint64(res[32:40], tmp[3]) - - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[24:32], tmp[0]) - binary.BigEndian.PutUint64(res[16:24], tmp[1]) - binary.BigEndian.PutUint64(res[8:16], tmp[2]) - binary.BigEndian.PutUint64(res[0:8], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[32:32+fp.Bytes]), p.X.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) res[0] |= msbMask @@ -939,41 +924,16 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate // p.Y.A1 | p.Y.A0 - tmp = p.Y.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[120:128], tmp[0]) - binary.BigEndian.PutUint64(res[112:120], tmp[1]) - binary.BigEndian.PutUint64(res[104:112], tmp[2]) - binary.BigEndian.PutUint64(res[96:104], tmp[3]) - - tmp = p.Y.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y.A0) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[64:64+fp.Bytes]), p.Y.A1) // we store X and mask the most significant word with our metadata mask // p.X.A1 | p.X.A0 - tmp = p.X.A1 - tmp.FromMont() - binary.BigEndian.PutUint64(res[24:32], tmp[0]) - binary.BigEndian.PutUint64(res[16:24], tmp[1]) - binary.BigEndian.PutUint64(res[8:16], tmp[2]) - binary.BigEndian.PutUint64(res[0:8], tmp[3]) - - tmp = p.X.A0 - tmp.FromMont() - binary.BigEndian.PutUint64(res[56:64], tmp[0]) - binary.BigEndian.PutUint64(res[48:56], tmp[1]) - binary.BigEndian.PutUint64(res[40:48], tmp[2]) - binary.BigEndian.PutUint64(res[32:40], tmp[3]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X.A1) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[32:32+fp.Bytes]), p.X.A0) res[0] |= mUncompressed @@ -1020,11 +980,19 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { if mData == mUncompressed { // read X and Y coordinates // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(buf[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // p.Y.A1 | p.Y.A0 - p.Y.A1.SetBytes(buf[fp.Bytes*2 : fp.Bytes*3]) - p.Y.A0.SetBytes(buf[fp.Bytes*3 : fp.Bytes*4]) + if err := p.Y.A1.SetBytesCanonical(buf[fp.Bytes*2 : fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.Y.A0.SetBytesCanonical(buf[fp.Bytes*3 : fp.Bytes*4]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1045,8 +1013,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } var YSquared, Y fptower.E2 @@ -1122,7 +1094,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1131,7 +1103,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1141,12 +1113,16 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { // read X coordinate // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return false, err + } // store mData in p.Y.A0[0] p.Y.A0[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bn254/multiexp.go b/ecc/bn254/multiexp.go index 4c120b4be..bb4a7d6e1 100644 --- a/ecc/bn254/multiexp.go +++ b/ecc/bn254/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,118 +92,179 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 6: - p.msmC6(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 7: - p.msmC7(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC7] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] case 9: - p.msmC9(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC9] case 10: - p.msmC10(points, scalars, splitFirstChunk) - + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 11: - p.msmC11(points, scalars, splitFirstChunk) - + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 12: - p.msmC12(points, scalars, splitFirstChunk) - + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 13: - p.msmC13(points, scalars, splitFirstChunk) - + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC13] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC13, bucketG1AffineC13, bitSetC13, pG1AffineC13, ppG1AffineC13, qG1AffineC13, cG1AffineC13] case 14: - p.msmC14(points, scalars, splitFirstChunk) - + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC14] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC14, bucketG1AffineC14, bitSetC14, pG1AffineC14, ppG1AffineC14, qG1AffineC14, cG1AffineC14] case 15: - p.msmC15(points, scalars, splitFirstChunk) - + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC15] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC15, bucketG1AffineC15, bitSetC15, pG1AffineC15, ppG1AffineC15, qG1AffineC15, cG1AffineC15] case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -360,1846 +284,447 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { + var _p G2Jac + if _, err := _p.MultiExp(points, scalars, config); err != nil { + return nil, err + } + p.FromJacobian(&_p) + return p, nil +} - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) +// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf +// +// This call return an error if len(scalars) != len(points) or if provided config is invalid. +func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { + // note: + // each of the msmCX method is the same, except for the c constant it declares + // duplicating (through template generation) these methods allows to declare the buckets on the stack + // the choice of c needs to be improved: + // there is a theoritical value that gives optimal asymptotics + // but in practice, other factors come into play, including: + // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 + // * number of CPUs + // * cache friendliness (which depends on the host, G1 or G2... ) + // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } + // for each msmCX + // step 1 + // we compute, for each scalars over c-bit wide windows, nbChunk digits + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + // negative digits will be processed in the next step as adding -G into the bucket instead of G + // (computing -G is cheap, and this saves us half of the buckets) + // step 2 + // buckets are declared on the stack + // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) + // we use jacobian extended formulas here as they are faster than mixed addition + // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel + // step 3 + // reduce the buckets weigthed sums into our result (msmReduceChunk) - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) + // ensure len(points) == len(scalars) + nbPoints := len(points) + if nbPoints != len(scalars) { + return nil, errors.New("len(points) != len(scalars)") } - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } + // if nbTasks is not set, use all available CPUs + if config.NbTasks <= 0 { + config.NbTasks = runtime.NumCPU() + } else if config.NbTasks > 1024 { + return nil, errors.New("invalid config: config.NbTasks > 1024") + } - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) + // here, we compute the best C for nbPoints + // we split recursively until nbChunks(c) >= nbTasks, + bestC := func(nbPoints int) uint64 { + // implemented msmC methods (the c we use must be in this slice) + implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + var C uint64 + // approximate cost (in group operations) + // cost = bits/c * (nbPoints + 2^{c}) + // this needs to be verified empirically. + // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results + min := math.MaxFloat64 + for _, c := range implementedCs { + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) + cost := float64(cc) / float64(c) + if cost < min { + min = cost + C = c + } } + return C } - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } - total.add(&runningSum) } - chRes <- total + _innerMsmG2(p, C, points, scalars, config) + return p, nil } -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) // for each chunk, spawn one go routine that'll loop through all the scalars in the // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // note that buckets is an array allocated on the stack and this is critical for performance // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended + chChunks := make([]chan g2JacExtended, nbChunks) for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + chChunks[i] = make(chan g2JacExtended, 1) } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - return msmReduceChunkG1Affine(p, c, chChunks[:]) + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) } -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { + switch c { - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + case 2: + return processChunkG2Jacobian[bucketg2JacExtendedC2] + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] + case 4: + return processChunkG2Jacobian[bucketg2JacExtendedC4] + case 5: + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] + case 7: + return processChunkG2Jacobian[bucketg2JacExtendedC7] + case 8: + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 9: + return processChunkG2Jacobian[bucketg2JacExtendedC9] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] + case 13: + const batchSize = 350 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC13] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC13, bucketG2AffineC13, bitSetC13, pG2AffineC13, ppG2AffineC13, qG2AffineC13, cG2AffineC13] + case 14: + const batchSize = 400 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC14] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC14, bucketG2AffineC14, bitSetC14, pG2AffineC14, ppG2AffineC14, qG2AffineC14, cG2AffineC14] + case 15: + const batchSize = 500 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC15] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC15, bucketG2AffineC15, bitSetC15, pG2AffineC15, ppG2AffineC15, qG2AffineC15, cG2AffineC15] + case 16: + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) } -func (p *G1Jac) msmC6(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC7(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC9(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC10(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC11(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC12(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC13(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC14(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC15(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC20(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC21(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Affine) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Affine, error) { - var _p G2Jac - if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err - } - p.FromJacobian(&_p) - return p, nil -} - -// MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// -// This call return an error if len(scalars) != len(points) or if provided config is invalid. -func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) (*G2Jac, error) { - // note: - // each of the msmCX method is the same, except for the c constant it declares - // duplicating (through template generation) these methods allows to declare the buckets on the stack - // the choice of c needs to be improved: - // there is a theoritical value that gives optimal asymptotics - // but in practice, other factors come into play, including: - // * if c doesn't divide 64, the word size, then we're bound to select bits over 2 words of our scalars, instead of 1 - // * number of CPUs - // * cache friendliness (which depends on the host, G1 or G2... ) - // --> for example, on BN254, a G1 point fits into one cache line of 64bytes, but a G2 point don't. - - // for each msmCX - // step 1 - // we compute, for each scalars over c-bit wide windows, nbChunk digits - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - // negative digits will be processed in the next step as adding -G into the bucket instead of G - // (computing -G is cheap, and this saves us half of the buckets) - // step 2 - // buckets are declared on the stack - // notice that we have 2^{c-1} buckets instead of 2^{c} (see step1) - // we use jacobian extended formulas here as they are faster than mixed addition - // msmProcessChunk places points into buckets base on their selector and return the weighted bucket sum in given channel - // step 3 - // reduce the buckets weigthed sums into our result (msmReduceChunk) - - // ensure len(points) == len(scalars) - nbPoints := len(points) - if nbPoints != len(scalars) { - return nil, errors.New("len(points) != len(scalars)") - } - - // if nbTasks is not set, use all available CPUs - if config.NbTasks <= 0 { - config.NbTasks = runtime.NumCPU() - } else if config.NbTasks > 1024 { - return nil, errors.New("invalid config: config.NbTasks > 1024") - } - - // here, we compute the best C for nbPoints - // we split recursively until nbChunks(c) >= nbTasks, - bestC := func(nbPoints int) uint64 { - // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} - var C uint64 - // approximate cost (in group operations) - // cost = bits/c * (nbPoints + 2^{c}) - // this needs to be verified empirically. - // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results - min := math.MaxFloat64 - for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) - cost := float64(cc) / float64(c) - if cost < min { - min = cost - C = c - } - } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } - return C - } - - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 - } - } - - // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) - } - - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) - } - close(chDone) - return p, nil -} - -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { - - switch c { - - case 4: - p.msmC4(points, scalars, splitFirstChunk) - - case 5: - p.msmC5(points, scalars, splitFirstChunk) - - case 6: - p.msmC6(points, scalars, splitFirstChunk) - - case 7: - p.msmC7(points, scalars, splitFirstChunk) - - case 8: - p.msmC8(points, scalars, splitFirstChunk) - - case 9: - p.msmC9(points, scalars, splitFirstChunk) - - case 10: - p.msmC10(points, scalars, splitFirstChunk) - - case 11: - p.msmC11(points, scalars, splitFirstChunk) - - case 12: - p.msmC12(points, scalars, splitFirstChunk) - - case 13: - p.msmC13(points, scalars, splitFirstChunk) - - case 14: - p.msmC14(points, scalars, splitFirstChunk) - - case 15: - p.msmC15(points, scalars, splitFirstChunk) - - case 16: - p.msmC16(points, scalars, splitFirstChunk) - - case 20: - p.msmC20(points, scalars, splitFirstChunk) - - case 21: - p.msmC21(points, scalars, splitFirstChunk) - - default: - panic("not implemented") - } -} - -// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp -func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { - var _p g2JacExtended - totalj := <-chChunks[len(chChunks)-1] - _p.Set(&totalj) - for j := len(chChunks) - 2; j >= 0; j-- { - for l := 0; l < c; l++ { - _p.double(&_p) - } - totalj := <-chChunks[j] - _p.add(&totalj) - } - - return p.unsafeFromJacExtended(&_p) -} - -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC6(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 6 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC7(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 7 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC9(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 9 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC10(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 10 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} - -func (p *G2Jac) msmC11(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 11 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() +// msmReduceChunkG2Affine reduces the weighted sum of the buckets into the result of the multiExp +func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2Jac { + var _p g2JacExtended + totalj := <-chChunks[len(chChunks)-1] + _p.Set(&totalj) + for j := len(chChunks) - 2; j >= 0; j-- { + for l := 0; l < c; l++ { + _p.double(&_p) + } + totalj := <-chChunks[j] + _p.add(&totalj) } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return p.unsafeFromJacExtended(&_p) } -func (p *G2Jac) msmC12(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 12 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions - return msmReduceChunkG2Affine(p, c, chChunks[:]) + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC13(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 13 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c } -func (p *G2Jac) msmC14(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 14 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC15(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 15 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + digits := make([]uint16, len(scalars)*int(nbChunks)) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() -func (p *G2Jac) msmC20(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 20 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + var carry int - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // init with carry if any + digit := carry + carry = 0 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC21(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 21 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - // c doesn't divide 256, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bn254/multiexp_affine.go b/ecc/bn254/multiexp_affine.go new file mode 100644 index 000000000..c0fd33431 --- /dev/null +++ b/ecc/bn254/multiexp_affine.go @@ -0,0 +1,688 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bn254 + +import ( + "github.com/consensys/gnark-crypto/ecc/bn254/fp" + "github.com/consensys/gnark-crypto/ecc/bn254/internal/fptower" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC13 [4096]G1Affine +type bucketG1AffineC14 [8192]G1Affine +type bucketG1AffineC15 [16384]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC11 | + bucketG1AffineC12 | + bucketG1AffineC13 | + bucketG1AffineC14 | + bucketG1AffineC15 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC11 | + cG1AffineC12 | + cG1AffineC13 | + cG1AffineC14 | + cG1AffineC15 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC11 | + pG1AffineC12 | + pG1AffineC13 | + pG1AffineC14 | + pG1AffineC15 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC11 | + ppG1AffineC12 | + ppG1AffineC13 | + ppG1AffineC14 | + ppG1AffineC15 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC11 | + qG1AffineC12 | + qG1AffineC13 | + qG1AffineC14 | + qG1AffineC15 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 350 when c = 13 +type cG1AffineC13 [350]fp.Element +type pG1AffineC13 [350]G1Affine +type ppG1AffineC13 [350]*G1Affine +type qG1AffineC13 [350]batchOpG1Affine + +// batch size 400 when c = 14 +type cG1AffineC14 [400]fp.Element +type pG1AffineC14 [400]G1Affine +type ppG1AffineC14 [400]*G1Affine +type qG1AffineC14 [400]batchOpG1Affine + +// batch size 500 when c = 15 +type cG1AffineC15 [500]fp.Element +type pG1AffineC15 [500]G1Affine +type ppG1AffineC15 [500]*G1Affine +type qG1AffineC15 [500]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC13 [4096]G2Affine +type bucketG2AffineC14 [8192]G2Affine +type bucketG2AffineC15 [16384]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC11 | + bucketG2AffineC12 | + bucketG2AffineC13 | + bucketG2AffineC14 | + bucketG2AffineC15 | + bucketG2AffineC16 +} + +// array of coordinates fptower.E2 +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC11 | + cG2AffineC12 | + cG2AffineC13 | + cG2AffineC14 | + cG2AffineC15 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC11 | + pG2AffineC12 | + pG2AffineC13 | + pG2AffineC14 | + pG2AffineC15 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC11 | + ppG2AffineC12 | + ppG2AffineC13 | + ppG2AffineC14 | + ppG2AffineC15 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC11 | + qG2AffineC12 | + qG2AffineC13 | + qG2AffineC14 | + qG2AffineC15 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fptower.E2 +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fptower.E2 +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fptower.E2 +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 350 when c = 13 +type cG2AffineC13 [350]fptower.E2 +type pG2AffineC13 [350]G2Affine +type ppG2AffineC13 [350]*G2Affine +type qG2AffineC13 [350]batchOpG2Affine + +// batch size 400 when c = 14 +type cG2AffineC14 [400]fptower.E2 +type pG2AffineC14 [400]G2Affine +type ppG2AffineC14 [400]*G2Affine +type qG2AffineC14 [400]batchOpG2Affine + +// batch size 500 when c = 15 +type cG2AffineC15 [500]fptower.E2 +type pG2AffineC15 [500]G2Affine +type ppG2AffineC15 [500]*G2Affine +type qG2AffineC15 [500]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fptower.E2 +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC2 [2]bool +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC7 [64]bool +type bitSetC8 [128]bool +type bitSetC9 [256]bool +type bitSetC10 [512]bool +type bitSetC11 [1024]bool +type bitSetC12 [2048]bool +type bitSetC13 [4096]bool +type bitSetC14 [8192]bool +type bitSetC15 [16384]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC7 | + bitSetC8 | + bitSetC9 | + bitSetC10 | + bitSetC11 | + bitSetC12 | + bitSetC13 | + bitSetC14 | + bitSetC15 | + bitSetC16 +} diff --git a/ecc/bn254/multiexp_jacobian.go b/ecc/bn254/multiexp_jacobian.go new file mode 100644 index 000000000..0bd2482a9 --- /dev/null +++ b/ecc/bn254/multiexp_jacobian.go @@ -0,0 +1,175 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bn254 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC7 [64]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC9 [256]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC13 [4096]g1JacExtended +type bucketg1JacExtendedC14 [8192]g1JacExtended +type bucketg1JacExtendedC15 [16384]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC7 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC9 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC13 | + bucketg1JacExtendedC14 | + bucketg1JacExtendedC15 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC2 [2]g2JacExtended +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC7 [64]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC9 [256]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC13 [4096]g2JacExtended +type bucketg2JacExtendedC14 [8192]g2JacExtended +type bucketg2JacExtendedC15 [16384]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC2 | + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC7 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC9 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC13 | + bucketg2JacExtendedC14 | + bucketg2JacExtendedC15 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bn254/multiexp_test.go b/ecc/bn254/multiexp_test.go index 12055f3eb..d1b70087c 100644 --- a/ecc/bn254/multiexp_test.go +++ b/ecc/bn254/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bn254/twistededwards/eddsa/doc.go b/ecc/bn254/twistededwards/eddsa/doc.go index ca822aea8..27ff44aca 100644 --- a/ecc/bn254/twistededwards/eddsa/doc.go +++ b/ecc/bn254/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bn254's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bn254/twistededwards/eddsa/eddsa_test.go b/ecc/bn254/twistededwards/eddsa/eddsa_test.go index c05a995a8..3bf8c9d98 100644 --- a/ecc/bn254/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bn254/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bn254/twistededwards/eddsa/marshal.go b/ecc/bn254/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bn254/twistededwards/eddsa/marshal.go +++ b/ecc/bn254/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bn254/twistededwards/point.go b/ecc/bn254/twistededwards/point.go index 427e5cb27..09303663c 100644 --- a/ecc/bn254/twistededwards/point.go +++ b/ecc/bn254/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bw6-633/bw6-633.go b/ecc/bw6-633/bw6-633.go index 5390374c8..9fb439325 100644 --- a/ecc/bw6-633/bw6-633.go +++ b/ecc/bw6-633/bw6-633.go @@ -1,22 +1,28 @@ // Package bw6633 efficient elliptic curve, pairing and hash to curve implementation for bw6-633. // // bw6-633: A Brezing--Weng curve (2-chain with bls24-315) -// embedding degree k=6 -// seed x₀=-3218079743 -// 𝔽p: p=20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 -// 𝔽r: r=39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// (E/𝔽p): Y²=X³+4 -// (Eₜ/𝔽p): Y² = X³+8 (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// +// embedding degree k=6 +// seed x₀=-3218079743 +// 𝔽p: p=20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 +// 𝔽r: r=39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// (E/𝔽p): Y²=X³+4 +// (Eₜ/𝔽p): Y² = X³+8 (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// // Extension fields tower: -// 𝔽p³[u] = 𝔽p/u³-2 -// 𝔽p⁶[v] = 𝔽p²/v²-u +// +// 𝔽p³[u] = 𝔽p/u³-2 +// 𝔽p⁶[v] = 𝔽p²/v²-u +// // optimal Ate loops: -// x₀+1, x₀^5-x₀^4-x₀ +// +// x₀+1, x₀^5-x₀^4-x₀ +// // Security: estimated 124-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 315 bits and p⁶ is 3798 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bw6633 diff --git a/ecc/bw6-633/fp/doc.go b/ecc/bw6-633/fp/doc.go index 56f6b617b..ac522839d 100644 --- a/ecc/bw6-633/fp/doc.go +++ b/ecc/bw6-633/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [10]uint64 // -// Usage +// type Element [10]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 -// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d +// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 +// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 1f571e1d9..8f3355e98 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 10 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 -// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d +// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 +// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [10]uint64 const ( - Limbs = 10 // number of 64 bits words needed to represent a Element - Bits = 633 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 10 // number of 64 bits words needed to represent a Element + Bits = 633 // number of bits needed to represent a Element + Bytes = 80 // number of bytes needed to represent a Element ) // Field modulus q @@ -80,8 +80,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 -// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d +// q[base10] = 20494478644167774678813387386538961497669590920908778075528754551012016751717791778743535050360001387419576570244406805463255765034468441182772056330021723098661967429339971741066259394985997 +// q[base16] = 0x126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -90,12 +90,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 13046692460116554043 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("126633cc0f35f63fc1a174f01d72ab5a8fcd8c75d79d2c74e59769ad9bbda2f8152a6c0fadea490b8da9f5e83f57c497e0e8850edbda407d7b5ce7ab839c2253d369bd31147f73cd74916ea4570000d", 16) } @@ -103,8 +97,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -115,7 +110,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -151,14 +146,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -280,15 +276,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -300,15 +294,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[9] > _x[9] { return 1 } else if _z[9] < _x[9] { @@ -369,8 +360,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 7756477793448755207, 0) @@ -485,67 +475,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -565,7 +497,7 @@ func (z *Element) Add(x, y *Element) *Element { z[8], carry = bits.Add64(x[8], y[8], carry) z[9], _ = bits.Add64(x[9], y[9], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -597,7 +529,7 @@ func (z *Element) Double(x *Element) *Element { z[8], carry = bits.Add64(x[8], x[8], carry) z[9], _ = bits.Add64(x[9], x[9], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -680,263 +612,411 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [10]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd1(v, y[6], c[1]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd1(v, y[7], c[1]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd1(v, y[8], c[1]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd1(v, y[9], c[1]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 6 - v := x[6] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 7 - v := x[7] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 8 - v := x[8] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - t[9], t[8] = madd3(m, q9, c[0], c[2], c[1]) - } - { - // round 9 - v := x[9] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], z[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], z[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], z[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], z[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - z[9], z[8] = madd3(m, q9, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [11]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + C, t[6] = madd1(y[0], x[6], C) + C, t[7] = madd1(y[0], x[7], C) + C, t[8] = madd1(y[0], x[8], C) + C, t[9] = madd1(y[0], x[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + C, t[6] = madd2(y[1], x[6], t[6], C) + C, t[7] = madd2(y[1], x[7], t[7], C) + C, t[8] = madd2(y[1], x[8], t[8], C) + C, t[9] = madd2(y[1], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + C, t[6] = madd2(y[2], x[6], t[6], C) + C, t[7] = madd2(y[2], x[7], t[7], C) + C, t[8] = madd2(y[2], x[8], t[8], C) + C, t[9] = madd2(y[2], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + C, t[6] = madd2(y[3], x[6], t[6], C) + C, t[7] = madd2(y[3], x[7], t[7], C) + C, t[8] = madd2(y[3], x[8], t[8], C) + C, t[9] = madd2(y[3], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + C, t[6] = madd2(y[4], x[6], t[6], C) + C, t[7] = madd2(y[4], x[7], t[7], C) + C, t[8] = madd2(y[4], x[8], t[8], C) + C, t[9] = madd2(y[4], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + C, t[6] = madd2(y[5], x[6], t[6], C) + C, t[7] = madd2(y[5], x[7], t[7], C) + C, t[8] = madd2(y[5], x[8], t[8], C) + C, t[9] = madd2(y[5], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[6], x[0], t[0]) + C, t[1] = madd2(y[6], x[1], t[1], C) + C, t[2] = madd2(y[6], x[2], t[2], C) + C, t[3] = madd2(y[6], x[3], t[3], C) + C, t[4] = madd2(y[6], x[4], t[4], C) + C, t[5] = madd2(y[6], x[5], t[5], C) + C, t[6] = madd2(y[6], x[6], t[6], C) + C, t[7] = madd2(y[6], x[7], t[7], C) + C, t[8] = madd2(y[6], x[8], t[8], C) + C, t[9] = madd2(y[6], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[7], x[0], t[0]) + C, t[1] = madd2(y[7], x[1], t[1], C) + C, t[2] = madd2(y[7], x[2], t[2], C) + C, t[3] = madd2(y[7], x[3], t[3], C) + C, t[4] = madd2(y[7], x[4], t[4], C) + C, t[5] = madd2(y[7], x[5], t[5], C) + C, t[6] = madd2(y[7], x[6], t[6], C) + C, t[7] = madd2(y[7], x[7], t[7], C) + C, t[8] = madd2(y[7], x[8], t[8], C) + C, t[9] = madd2(y[7], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[8], x[0], t[0]) + C, t[1] = madd2(y[8], x[1], t[1], C) + C, t[2] = madd2(y[8], x[2], t[2], C) + C, t[3] = madd2(y[8], x[3], t[3], C) + C, t[4] = madd2(y[8], x[4], t[4], C) + C, t[5] = madd2(y[8], x[5], t[5], C) + C, t[6] = madd2(y[8], x[6], t[6], C) + C, t[7] = madd2(y[8], x[7], t[7], C) + C, t[8] = madd2(y[8], x[8], t[8], C) + C, t[9] = madd2(y[8], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[9], x[0], t[0]) + C, t[1] = madd2(y[9], x[1], t[1], C) + C, t[2] = madd2(y[9], x[2], t[2], C) + C, t[3] = madd2(y[9], x[3], t[3], C) + C, t[4] = madd2(y[9], x[4], t[4], C) + C, t[5] = madd2(y[9], x[5], t[5], C) + C, t[6] = madd2(y[9], x[6], t[6], C) + C, t[7] = madd2(y[9], x[7], t[7], C) + C, t[8] = madd2(y[9], x[8], t[8], C) + C, t[9] = madd2(y[9], x[9], t[9], C) + + t[10], D = bits.Add64(t[10], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + + t[9], C = bits.Add64(t[10], C, 0) + t[10], _ = bits.Add64(0, D, C) + + if t[10] != 0 { + // we need to reduce, we have a result on 11 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], b = bits.Sub64(t[5], q5, b) + z[6], b = bits.Sub64(t[6], q6, b) + z[7], b = bits.Sub64(t[7], q7, b) + z[8], b = bits.Sub64(t[8], q8, b) + z[9], _ = bits.Sub64(t[9], q9, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + z[6] = t[6] + z[7] = t[7] + z[8] = t[8] + z[9] = t[9] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -950,7 +1030,6 @@ func _mulGeneric(z, x, y *Element) { z[8], b = bits.Sub64(z[8], q8, b) z[9], _ = bits.Sub64(z[9], q9, b) } - } func _fromMontGeneric(z *Element) { @@ -1108,7 +1187,7 @@ func _fromMontGeneric(z *Element) { z[9] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1126,7 +1205,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1214,6 +1293,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -1228,8 +1336,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -1261,23 +1369,35 @@ var rSquare = Element{ 35368377961363834, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[72:80], z[0]) + binary.BigEndian.PutUint64(b[64:72], z[1]) + binary.BigEndian.PutUint64(b[56:64], z[2]) + binary.BigEndian.PutUint64(b[48:56], z[3]) + binary.BigEndian.PutUint64(b[40:48], z[4]) + binary.BigEndian.PutUint64(b[32:40], z[5]) + binary.BigEndian.PutUint64(b[24:32], z[6]) + binary.BigEndian.PutUint64(b[16:24], z[7]) + binary.BigEndian.PutUint64(b[8:16], z[8]) + binary.BigEndian.PutUint64(b[0:8], z[9]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -1296,59 +1416,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[72:80], z[0]) - binary.BigEndian.PutUint64(b[64:72], z[1]) - binary.BigEndian.PutUint64(b[56:64], z[2]) - binary.BigEndian.PutUint64(b[48:56], z[3]) - binary.BigEndian.PutUint64(b[40:48], z[4]) - binary.BigEndian.PutUint64(b[32:40], z[5]) - binary.BigEndian.PutUint64(b[24:32], z[6]) - binary.BigEndian.PutUint64(b[16:24], z[7]) - binary.BigEndian.PutUint64(b[8:16], z[8]) - binary.BigEndian.PutUint64(b[0:8], z[9]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[72:80], _z[0]) - binary.BigEndian.PutUint64(res[64:72], _z[1]) - binary.BigEndian.PutUint64(res[56:64], _z[2]) - binary.BigEndian.PutUint64(res[48:56], _z[3]) - binary.BigEndian.PutUint64(res[40:48], _z[4]) - binary.BigEndian.PutUint64(res[32:40], _z[5]) - binary.BigEndian.PutUint64(res[24:32], _z[6]) - binary.BigEndian.PutUint64(res[16:24], _z[7]) - binary.BigEndian.PutUint64(res[8:16], _z[8]) - binary.BigEndian.PutUint64(res[0:8], _z[9]) +// Bits provides access to z by returning its value as a little-endian [10]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [10]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1361,19 +1471,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 80-byte integer. +// If e is not a 80-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1391,17 +1526,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1423,20 +1557,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1445,7 +1579,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1454,7 +1588,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1494,7 +1628,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1503,10 +1637,103 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 80-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[72:80]) + z[1] = binary.BigEndian.Uint64((*b)[64:72]) + z[2] = binary.BigEndian.Uint64((*b)[56:64]) + z[3] = binary.BigEndian.Uint64((*b)[48:56]) + z[4] = binary.BigEndian.Uint64((*b)[40:48]) + z[5] = binary.BigEndian.Uint64((*b)[32:40]) + z[6] = binary.BigEndian.Uint64((*b)[24:32]) + z[7] = binary.BigEndian.Uint64((*b)[16:24]) + z[8] = binary.BigEndian.Uint64((*b)[8:16]) + z[9] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[72:80], e[0]) + binary.BigEndian.PutUint64((*b)[64:72], e[1]) + binary.BigEndian.PutUint64((*b)[56:64], e[2]) + binary.BigEndian.PutUint64((*b)[48:56], e[3]) + binary.BigEndian.PutUint64((*b)[40:48], e[4]) + binary.BigEndian.PutUint64((*b)[32:40], e[5]) + binary.BigEndian.PutUint64((*b)[24:32], e[6]) + binary.BigEndian.PutUint64((*b)[16:24], e[7]) + binary.BigEndian.PutUint64((*b)[8:16], e[8]) + binary.BigEndian.PutUint64((*b)[0:8], e[9]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + z[6] = binary.LittleEndian.Uint64((*b)[48:56]) + z[7] = binary.LittleEndian.Uint64((*b)[56:64]) + z[8] = binary.LittleEndian.Uint64((*b)[64:72]) + z[9] = binary.LittleEndian.Uint64((*b)[72:80]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) + binary.LittleEndian.PutUint64((*b)[48:56], e[6]) + binary.LittleEndian.PutUint64((*b)[56:64], e[7]) + binary.LittleEndian.PutUint64((*b)[64:72], e[8]) + binary.LittleEndian.PutUint64((*b)[72:80], e[9]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1723,7 +1950,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1744,17 +1971,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1979,7 +2217,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[9], z[8] = madd2(m, q9, t[i+9], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-633/fp/element_mul_adx_amd64.s b/ecc/bw6-633/fp/element_mul_adx_amd64.s deleted file mode 100644 index f5926e5a0..000000000 --- a/ecc/bw6-633/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,1954 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xd74916ea4570000d -DATA q<>+8(SB)/8, $0x3d369bd31147f73c -DATA q<>+16(SB)/8, $0xd7b5ce7ab839c225 -DATA q<>+24(SB)/8, $0x7e0e8850edbda407 -DATA q<>+32(SB)/8, $0xb8da9f5e83f57c49 -DATA q<>+40(SB)/8, $0x8152a6c0fadea490 -DATA q<>+48(SB)/8, $0x4e59769ad9bbda2f -DATA q<>+56(SB)/8, $0xa8fcd8c75d79d2c7 -DATA q<>+64(SB)/8, $0xfc1a174f01d72ab5 -DATA q<>+72(SB)/8, $0x0126633cc0f35f63 -GLOBL q<>(SB), (RODATA+NOPTR), $80 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xb50f29ab0b03b13b -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $64-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - MOVQ x+8(FP), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // t[6] -> R8 - // t[7] -> R9 - // t[8] -> R10 - // t[9] -> R11 - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ 0(R12), R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ 8(R12), AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ 16(R12), AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R12), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R12), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R12), AX, R8 - ADOXQ AX, DI - - // (A,t[6]) := x[6]*y[0] + A - MULXQ 48(R12), AX, R9 - ADOXQ AX, R8 - - // (A,t[7]) := x[7]*y[0] + A - MULXQ 56(R12), AX, R10 - ADOXQ AX, R9 - - // (A,t[8]) := x[8]*y[0] + A - MULXQ 64(R12), AX, R11 - ADOXQ AX, R10 - - // (A,t[9]) := x[9]*y[0] + A - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[1] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[1] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[1] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[1] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[2] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[2] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[2] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[2] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[3] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[3] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[3] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[3] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[4] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[4] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[4] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[4] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[5] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[5] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[5] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[5] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 48(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[6] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[6] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[6] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[6] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[6] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[6] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[6] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[6] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[6] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[6] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 56(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[7] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[7] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[7] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[7] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[7] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[7] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[7] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[7] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[7] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[7] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 64(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[8] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[8] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[8] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[8] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[8] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[8] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[8] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[8] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[8] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[8] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // clear the flags - XORQ AX, AX - MOVQ 72(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[9] + A - MULXQ 0(R12), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[9] + A - ADCXQ BP, R15 - MULXQ 8(R12), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[9] + A - ADCXQ BP, CX - MULXQ 16(R12), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[9] + A - ADCXQ BP, BX - MULXQ 24(R12), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[9] + A - ADCXQ BP, SI - MULXQ 32(R12), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[9] + A - ADCXQ BP, DI - MULXQ 40(R12), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[9] + A - ADCXQ BP, R8 - MULXQ 48(R12), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[9] + A - ADCXQ BP, R9 - MULXQ 56(R12), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[9] + A - ADCXQ BP, R10 - MULXQ 64(R12), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[9] + A - ADCXQ BP, R11 - MULXQ 72(R12), AX, BP - ADOXQ AX, R11 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // t[9] = C + A - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ BP, R11 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - RET - -TEXT ·fromMont(SB), $64-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - MOVQ 64(DX), R10 - MOVQ 72(DX), R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - MOVQ $0, AX - ADCXQ AX, R11 - ADOXQ AX, R11 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11) using temp registers (R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - RET diff --git a/ecc/bw6-633/fp/element_mul_amd64.s b/ecc/bw6-633/fp/element_mul_amd64.s index e5d5e2469..f0b617228 100644 --- a/ecc/bw6-633/fp/element_mul_amd64.s +++ b/ecc/bw6-633/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fp/element_ops_amd64.go b/ecc/bw6-633/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.go +++ b/ecc/bw6-633/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-633/fp/element_ops_amd64.s b/ecc/bw6-633/fp/element_ops_amd64.s index 98d7aaa4e..119efe44f 100644 --- a/ecc/bw6-633/fp/element_ops_amd64.s +++ b/ecc/bw6-633/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-633/fp/element_ops_noasm.go b/ecc/bw6-633/fp/element_ops_noasm.go deleted file mode 100644 index 196fe5823..000000000 --- a/ecc/bw6-633/fp/element_ops_noasm.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 4881606927653498122, - 47978232019095094, - 8555661377410121478, - 17849732488791568215, - 5227097555314997552, - 839611732066804726, - 5234648925333584632, - 11936054402769696488, - 1228498468693814883, - 2857848702739380, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-633/fp/element_ops_purego.go b/ecc/bw6-633/fp/element_ops_purego.go new file mode 100644 index 000000000..69c68919e --- /dev/null +++ b/ecc/bw6-633/fp/element_ops_purego.go @@ -0,0 +1,1637 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 4881606927653498122, + 47978232019095094, + 8555661377410121478, + 17849732488791568215, + 5227097555314997552, + 839611732066804726, + 5234648925333584632, + 11936054402769696488, + 1228498468693814883, + 2857848702739380, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + u6, t6 = bits.Mul64(v, y[6]) + u7, t7 = bits.Mul64(v, y[7]) + u8, t8 = bits.Mul64(v, y[8]) + u9, t9 = bits.Mul64(v, y[9]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], _ = bits.Sub64(z[9], q9, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + u6, t6 = bits.Mul64(v, x[6]) + u7, t7 = bits.Mul64(v, x[7]) + u8, t8 = bits.Mul64(v, x[8]) + u9, t9 = bits.Mul64(v, x[9]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + c2, _ = bits.Add64(u9, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + + t8, c0 = bits.Add64(0, c1, c0) + u9, _ = bits.Add64(u9, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + c2, _ = bits.Add64(c2, 0, c0) + t8, c0 = bits.Add64(t9, t8, 0) + t9, _ = bits.Add64(u9, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], _ = bits.Sub64(z[9], q9, b) + } + return z +} diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index 57595382f..835b3d04d 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -650,7 +643,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -765,7 +758,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -778,13 +771,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -813,17 +806,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -874,7 +867,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -887,13 +880,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -922,17 +915,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -983,7 +976,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -996,7 +989,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1010,7 +1003,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1050,11 +1043,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1067,7 +1060,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1119,7 +1112,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1132,14 +1125,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1168,18 +1161,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1230,7 +1223,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1243,13 +1236,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1278,17 +1271,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1333,7 +1326,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1354,14 +1347,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1405,7 +1398,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1426,14 +1419,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1477,7 +1470,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1498,14 +1491,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1549,7 +1542,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1570,14 +1563,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1621,7 +1614,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1642,14 +1635,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2037,7 +2030,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2170,17 +2163,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2271,7 +2264,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2351,7 +2344,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2399,7 +2392,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord8, inversionCorrectionFactorWord9, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2441,7 +2434,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2598,7 +2591,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2811,11 +2804,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2852,7 +2845,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2870,7 +2863,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2881,7 +2874,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-633/fr/doc.go b/ecc/bw6-633/fr/doc.go index 39b239026..cbb389404 100644 --- a/ecc/bw6-633/fr/doc.go +++ b/ecc/bw6-633/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [5]uint64 // -// Usage +// type Element [5]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index d8661fd43..ba049e38f 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 5 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [5]uint64 const ( - Limbs = 5 // number of 64 bits words needed to represent a Element - Bits = 315 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 5 // number of 64 bits words needed to represent a Element + Bits = 315 // number of bits needed to represent a Element + Bytes = 40 // number of bytes needed to represent a Element ) // Field modulus q @@ -70,8 +70,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 -// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 +// q[base10] = 39705142709513438335025689890408969744933502416914749335064285505637884093126342347073617133569 +// q[base16] = 0x4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -80,12 +80,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 8083954730842193919 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001", 16) } @@ -93,8 +87,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -105,7 +100,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -136,14 +131,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -255,15 +251,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -275,15 +269,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[4] > _x[4] { return 1 } else if _z[4] < _x[4] { @@ -319,8 +310,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 4031849214061838337, 0) @@ -415,67 +405,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -490,7 +422,7 @@ func (z *Element) Add(x, y *Element) *Element { z[3], carry = bits.Add64(x[3], y[3], carry) z[4], _ = bits.Add64(x[4], y[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -512,7 +444,7 @@ func (z *Element) Double(x *Element) *Element { z[3], carry = bits.Add64(x[3], x[3], carry) z[4], _ = bits.Add64(x[4], x[4], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -570,88 +502,181 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [5]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - t[4], t[3] = madd3(m, q4, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - z[4], z[3] = madd3(m, q4, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [6]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + + t[5], D = bits.Add64(t[5], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + + t[4], C = bits.Add64(t[5], C, 0) + t[5], _ = bits.Add64(0, D, C) + + if t[5] != 0 { + // we need to reduce, we have a result on 6 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], _ = bits.Sub64(t[4], q4, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -660,7 +685,6 @@ func _mulGeneric(z, x, y *Element) { z[3], b = bits.Sub64(z[3], q3, b) z[4], _ = bits.Sub64(z[4], q4, b) } - } func _fromMontGeneric(z *Element) { @@ -718,7 +742,7 @@ func _fromMontGeneric(z *Element) { z[4] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -731,7 +755,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -799,6 +823,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -813,8 +866,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -841,23 +894,30 @@ var rSquare = Element{ 150264569250089173, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[32:40], z[0]) + binary.BigEndian.PutUint64(b[24:32], z[1]) + binary.BigEndian.PutUint64(b[16:24], z[2]) + binary.BigEndian.PutUint64(b[8:16], z[3]) + binary.BigEndian.PutUint64(b[0:8], z[4]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -876,49 +936,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[32:40], z[0]) - binary.BigEndian.PutUint64(b[24:32], z[1]) - binary.BigEndian.PutUint64(b[16:24], z[2]) - binary.BigEndian.PutUint64(b[8:16], z[3]) - binary.BigEndian.PutUint64(b[0:8], z[4]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[32:40], _z[0]) - binary.BigEndian.PutUint64(res[24:32], _z[1]) - binary.BigEndian.PutUint64(res[16:24], _z[2]) - binary.BigEndian.PutUint64(res[8:16], _z[3]) - binary.BigEndian.PutUint64(res[0:8], _z[4]) +// Bits provides access to z by returning its value as a little-endian [5]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [5]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -931,19 +991,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 40-byte integer. +// If e is not a 40-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -961,17 +1046,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -993,20 +1077,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1015,7 +1099,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1024,7 +1108,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1064,7 +1148,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1073,10 +1157,83 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 40-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[32:40]) + z[1] = binary.BigEndian.Uint64((*b)[24:32]) + z[2] = binary.BigEndian.Uint64((*b)[16:24]) + z[3] = binary.BigEndian.Uint64((*b)[8:16]) + z[4] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[32:40], e[0]) + binary.BigEndian.PutUint64((*b)[24:32], e[1]) + binary.BigEndian.PutUint64((*b)[16:24], e[2]) + binary.BigEndian.PutUint64((*b)[8:16], e[3]) + binary.BigEndian.PutUint64((*b)[0:8], e[4]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1109,7 +1266,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1123,7 +1280,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(20) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1317,7 +1474,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1333,17 +1490,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1458,7 +1626,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[4], z[3] = madd2(m, q4, t[i+4], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-633/fr/element_mul_adx_amd64.s b/ecc/bw6-633/fr/element_mul_adx_amd64.s deleted file mode 100644 index c02648d3a..000000000 --- a/ecc/bw6-633/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,634 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x6fe802ff40300001 -DATA q<>+8(SB)/8, $0x421ee5da52bde502 -DATA q<>+16(SB)/8, $0xdec1d01aa27a1ae0 -DATA q<>+24(SB)/8, $0xd3f7498be97c5eaf -DATA q<>+32(SB)/8, $0x04c23a02b586d650 -GLOBL q<>(SB), (RODATA+NOPTR), $40 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x702ff9ff402fffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, rb0, rb1, rb2, rb3, rb4) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), DI - - // x[0] -> R9 - // x[1] -> R10 - // x[2] -> R11 - MOVQ 0(DI), R9 - MOVQ 8(DI), R10 - MOVQ 16(DI), R11 - MOVQ y+16(FP), R12 - - // A -> BP - // t[0] -> R14 - // t[1] -> R13 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // clear the flags - XORQ AX, AX - MOVQ 0(R12), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R9, R14, R13 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R10, AX, CX - ADOXQ AX, R13 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R11, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(DI), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 8(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 16(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 24(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // clear the flags - XORQ AX, AX - MOVQ 32(R12), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R9, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R13 - MULXQ R10, AX, BP - ADOXQ AX, R13 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R11, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(DI), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(DI), AX, BP - ADOXQ AX, SI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R8 - ADCXQ R14, AX - MOVQ R8, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // t[4] = C + A - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ BP, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (R8,DI,R12,R9,R10) - REDUCE(R14,R13,CX,BX,SI,R8,DI,R12,R9,R10) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R13 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R13, R14 - MULXQ q<>+8(SB), AX, R13 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R13 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R13 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - MOVQ $0, AX - ADCXQ AX, SI - ADOXQ AX, SI - - // reduce element(R14,R13,CX,BX,SI) using temp registers (DI,R8,R9,R10,R11) - REDUCE(R14,R13,CX,BX,SI,DI,R8,R9,R10,R11) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R13, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - RET diff --git a/ecc/bw6-633/fr/element_mul_amd64.s b/ecc/bw6-633/fr/element_mul_amd64.s index 94089b607..51165684d 100644 --- a/ecc/bw6-633/fr/element_mul_amd64.s +++ b/ecc/bw6-633/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-633/fr/element_ops_amd64.go b/ecc/bw6-633/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.go +++ b/ecc/bw6-633/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-633/fr/element_ops_amd64.s b/ecc/bw6-633/fr/element_ops_amd64.s index c70e0a5ce..9528ab595 100644 --- a/ecc/bw6-633/fr/element_ops_amd64.s +++ b/ecc/bw6-633/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-633/fr/element_ops_noasm.go b/ecc/bw6-633/fr/element_ops_noasm.go deleted file mode 100644 index 7c58443c8..000000000 --- a/ecc/bw6-633/fr/element_ops_noasm.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 8178485296672800069, - 8476448362227282520, - 14180928431697993131, - 4308307642551989706, - 120359802761433421, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-633/fr/element_ops_purego.go b/ecc/bw6-633/fr/element_ops_purego.go new file mode 100644 index 000000000..34d6c54fb --- /dev/null +++ b/ecc/bw6-633/fr/element_ops_purego.go @@ -0,0 +1,582 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 8178485296672800069, + 8476448362227282520, + 14180928431697993131, + 4308307642551989706, + 120359802761433421, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4 uint64 + var u0, u1, u2, u3, u4 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + c2, _ = bits.Add64(u4, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + + t3, c0 = bits.Add64(0, c1, c0) + u4, _ = bits.Add64(u4, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + c2, _ = bits.Add64(c2, 0, c0) + t3, c0 = bits.Add64(t4, t3, 0) + t4, _ = bits.Add64(u4, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], _ = bits.Sub64(z[4], q4, b) + } + return z +} diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index b42c74a07..2aeb9fba1 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -640,7 +633,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -755,7 +748,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -768,13 +761,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -803,17 +796,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -864,7 +857,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -877,13 +870,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -912,17 +905,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -973,7 +966,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -986,7 +979,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1000,7 +993,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1040,11 +1033,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1057,7 +1050,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1109,7 +1102,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1122,14 +1115,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1158,18 +1151,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1220,7 +1213,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1233,13 +1226,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1268,17 +1261,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1323,7 +1316,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1344,14 +1337,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1395,7 +1388,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1416,14 +1409,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1467,7 +1460,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1488,14 +1481,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1539,7 +1532,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1560,14 +1553,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1611,7 +1604,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1632,14 +1625,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2027,7 +2020,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2160,17 +2153,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2251,7 +2244,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2316,7 +2309,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2359,7 +2352,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord3, inversionCorrectionFactorWord4, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2401,7 +2394,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2548,7 +2541,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2761,11 +2754,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2802,7 +2795,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2820,7 +2813,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2831,7 +2824,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-633/fr/fri/fri.go b/ecc/bw6-633/fr/fri/fri.go index 2f102e40b..5da24b37e 100644 --- a/ecc/bw6-633/fr/fri/fri.go +++ b/ecc/bw6-633/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bw6-633/fr/gkr/gkr.go b/ecc/bw6-633/fr/gkr/gkr.go new file mode 100644 index 000000000..2b6717b2b --- /dev/null +++ b/ecc/bw6-633/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bw6-633/fr/gkr/gkr_test.go b/ecc/bw6-633/fr/gkr/gkr_test.go new file mode 100644 index 000000000..ec29f1716 --- /dev/null +++ b/ecc/bw6-633/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bw6-633/fr/kzg/kzg.go b/ecc/bw6-633/fr/kzg/kzg.go index 824365a3f..d9eb12d4c 100644 --- a/ecc/bw6-633/fr/kzg/kzg.go +++ b/ecc/bw6-633/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bw6633.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bw6633.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bw6-633/fr/mimc/decompose.go b/ecc/bw6-633/fr/mimc/decompose.go new file mode 100644 index 000000000..0f4cc36c9 --- /dev/null +++ b/ecc/bw6-633/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bw6-633/fr/mimc/decompose_test.go b/ecc/bw6-633/fr/mimc/decompose_test.go new file mode 100644 index 000000000..26518fd46 --- /dev/null +++ b/ecc/bw6-633/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bw6-633/fr/mimc/mimc.go b/ecc/bw6-633/fr/mimc/mimc.go index 687bd79e4..d02685515 100644 --- a/ecc/bw6-633/fr/mimc/mimc.go +++ b/ecc/bw6-633/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bw6-633/fr/pedersen/pedersen.go b/ecc/bw6-633/fr/pedersen/pedersen.go new file mode 100644 index 000000000..1c5f269bf --- /dev/null +++ b/ecc/bw6-633/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bw6633.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bw6633.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bw6633.G1Affine + basisExpSigma []bw6633.G1Affine +} + +func randomOnG2() (bw6633.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6633.G2Affine{}, err + } + return bw6633.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bw6633.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bw6633.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bw6633.G1Affine, knowledgeProof bw6633.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bw6633.G1Affine, knowledgeProof bw6633.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bw6633.Pair([]bw6633.G1Affine{commitment, knowledgeProof}, []bw6633.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bw6-633/fr/pedersen/pedersen_test.go b/ecc/bw6-633/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..17c1f9aa2 --- /dev/null +++ b/ecc/bw6-633/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-633" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bw6633.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6633.G1Affine{}, err + } + return bw6633.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bw6633.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bw6633.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bw6-633/fr/plookup/vector.go b/ecc/bw6-633/fr/plookup/vector.go index 86e47de39..ae8f16fd4 100644 --- a/ecc/bw6-633/fr/plookup/vector.go +++ b/ecc/bw6-633/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bw6-633/fr/polynomial/multilin.go b/ecc/bw6-633/fr/polynomial/multilin.go index 8d0a683fe..df3529f01 100644 --- a/ecc/bw6-633/fr/polynomial/multilin.go +++ b/ecc/bw6-633/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bw6-633/fr/polynomial/polynomial_test.go b/ecc/bw6-633/fr/polynomial/polynomial_test.go index ade2ef832..e91d99ac0 100644 --- a/ecc/bw6-633/fr/polynomial/polynomial_test.go +++ b/ecc/bw6-633/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bw6-633/fr/polynomial/pool.go b/ecc/bw6-633/fr/polynomial/pool.go index 717315fa7..f998232c4 100644 --- a/ecc/bw6-633/fr/polynomial/pool.go +++ b/ecc/bw6-633/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bw6-633/fr/sumcheck/sumcheck.go b/ecc/bw6-633/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..40be96534 --- /dev/null +++ b/ecc/bw6-633/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bw6-633/fr/sumcheck/sumcheck_test.go b/ecc/bw6-633/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..1ae0784c7 --- /dev/null +++ b/ecc/bw6-633/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go b/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..503f4cb4c --- /dev/null +++ b/ecc/bw6-633/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bw6-633/g1.go b/ecc/bw6-633/g1.go index 0228f01bb..95e71fb1d 100644 --- a/ecc/bw6-633/g1.go +++ b/ecc/bw6-633/g1.go @@ -17,13 +17,12 @@ package bw6633 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -55,6 +54,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -378,6 +384,7 @@ func (p *G1Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // 3r P = (x+1)ϕ(P) + (-x^5 + x⁴ + x)P func (p *G1Jac) IsInSubGroup() bool { @@ -477,8 +484,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -517,6 +524,7 @@ func (p *G1Affine) ClearCofactor(a *G1Affine) *G1Affine { // ClearCofactor maps a point in E(Fp) to E(Fp)[r] func (p *G1Jac) ClearCofactor(a *G1Jac) *G1Jac { + var uP, vP, wP, L0, L1, tmp G1Jac var v, one, uPlusOne, uMinusOne, d1, d2, ht big.Int one.SetInt64(1) @@ -612,15 +620,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -632,11 +640,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -661,6 +665,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -978,95 +984,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -1080,3 +1063,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-633/g1_test.go b/ecc/bw6-633/g1_test.go index 6caf91227..bf3e45975 100644 --- a/ecc/bw6-633/g1_test.go +++ b/ecc/bw6-633/g1_test.go @@ -19,6 +19,7 @@ package bw6633 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-633/g2.go b/ecc/bw6-633/g2.go index 33adec5f3..1bab9608a 100644 --- a/ecc/bw6-633/g2.go +++ b/ecc/bw6-633/g2.go @@ -17,13 +17,12 @@ package bw6633 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fp.Element } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -365,6 +371,7 @@ func (p *G2Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // 3r P = (x+1)ϕ(P) + (-x^5 + x⁴ + x)P func (p *G2Jac) IsInSubGroup() bool { @@ -464,8 +471,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -602,15 +609,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -622,11 +629,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -651,6 +654,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fp.Element @@ -844,93 +849,70 @@ func (p *g2JacExtended) doubleMixed(q *G2Affine) *g2JacExtended { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -943,3 +925,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-633/g2_test.go b/ecc/bw6-633/g2_test.go index 32773e971..e5091c050 100644 --- a/ecc/bw6-633/g2_test.go +++ b/ecc/bw6-633/g2_test.go @@ -19,6 +19,7 @@ package bw6633 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" @@ -325,7 +326,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -445,8 +446,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -459,7 +459,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -486,6 +486,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -498,8 +525,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-633/hash_to_g1.go b/ecc/bw6-633/hash_to_g1.go index 7e6d7f1a5..bb4a50239 100644 --- a/ecc/bw6-633/hash_to_g1.go +++ b/ecc/bw6-633/hash_to_g1.go @@ -17,7 +17,6 @@ package bw6633 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" "math/big" @@ -258,35 +257,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -304,11 +282,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -324,9 +302,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bw6-633/hash_to_g1_test.go b/ecc/bw6-633/hash_to_g1_test.go index eaa3c8188..8d299c729 100644 --- a/ecc/bw6-633/hash_to_g1_test.go +++ b/ecc/bw6-633/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-633/hash_to_g2.go b/ecc/bw6-633/hash_to_g2.go index 9b7872468..23c9809a6 100644 --- a/ecc/bw6-633/hash_to_g2.go +++ b/ecc/bw6-633/hash_to_g2.go @@ -237,8 +237,8 @@ func g2EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -256,11 +256,11 @@ func MapToG2(u fp.Element) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SSWU map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -276,9 +276,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SSWU map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G2Affine{}, err } diff --git a/ecc/bw6-633/hash_to_g2_test.go b/ecc/bw6-633/hash_to_g2_test.go index 7bc51d1b8..130f43a0a 100644 --- a/ecc/bw6-633/hash_to_g2_test.go +++ b/ecc/bw6-633/hash_to_g2_test.go @@ -62,7 +62,7 @@ func TestG2SqrtRatio(t *testing.T) { func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE2Prime(p G2Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE2Prime(p G2Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-633/internal/fptower/e3.go b/ecc/bw6-633/internal/fptower/e3.go index 8e309811c..c46ce8b57 100644 --- a/ecc/bw6-633/internal/fptower/e3.go +++ b/ecc/bw6-633/internal/fptower/e3.go @@ -87,6 +87,10 @@ func (z *E3) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() && z.A2.IsZero() } +func (z *E3) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() && z.A2.IsZero() +} + // Neg negates the E3 number func (z *E3) Neg(x *E3) *E3 { z.A0.Neg(&x.A0) @@ -95,22 +99,6 @@ func (z *E3) Neg(x *E3) *E3 { return z } -// ToMont converts to Mont form -func (z *E3) ToMont() *E3 { - z.A0.ToMont() - z.A1.ToMont() - z.A2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E3) FromMont() *E3 { - z.A0.FromMont() - z.A1.FromMont() - z.A2.FromMont() - return z -} - // Add adds two elements of E3 func (z *E3) Add(x, y *E3) *E3 { z.A0.Add(&x.A0, &y.A0) diff --git a/ecc/bw6-633/internal/fptower/e6.go b/ecc/bw6-633/internal/fptower/e6.go index e6c171b93..0a5c22e97 100644 --- a/ecc/bw6-633/internal/fptower/e6.go +++ b/ecc/bw6-633/internal/fptower/e6.go @@ -68,20 +68,6 @@ func (z *E6) SetOne() *E6 { return z } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - return z -} - // Add set z=x+y in E6 and return z func (z *E6) Add(x, y *E6) *E6 { z.B0.Add(&x.B0, &y.B0) @@ -119,6 +105,10 @@ func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() } +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() +} + // Mul set z=x*y in E6 and return z func (z *E6) Mul(x, y *E6) *E6 { var a, b, c E3 @@ -226,9 +216,12 @@ func (z *E6) CyclotomicSquareCompressed(x *E6) *E6 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -253,7 +246,7 @@ func (z *E6) DecompressKarabina(x *E6) *E6 { t[1].Sub(&t[0], &x.B0.A2). Double(&t[1]). Add(&t[1], &t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t[2].Square(&x.B1.A2) t[0].MulByNonResidue(&t[2]). Add(&t[0], &t[1]) @@ -289,9 +282,12 @@ func (z *E6) DecompressKarabina(x *E6) *E6 { // BatchDecompressKarabina multiple Karabina's cyclotomic square results // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -327,7 +323,7 @@ func BatchDecompressKarabina(x []E6) []E6 { t1[i].Sub(&t0[i], &x[i].B0.A2). Double(&t1[i]). Add(&t1[i], &t0[i]) - // t0 = E * g5² + t1 + // t0 = E * g5² + t1 t2[i].Square(&x[i].B1.A2) t0[i].MulByNonResidue(&t2[i]). Add(&t0[i], &t1[i]) @@ -601,8 +597,8 @@ func (z *E6) ExpGLV(x E6, k *big.Int) *E6 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { diff --git a/ecc/bw6-633/marshal.go b/ecc/bw6-633/marshal.go index 0d162bcca..88d800df2 100644 --- a/ecc/bw6-633/marshal.go +++ b/ecc/bw6-633/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,18 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - binary.BigEndian.PutUint64(res[32:40], tmp[5]) - binary.BigEndian.PutUint64(res[24:32], tmp[6]) - binary.BigEndian.PutUint64(res[16:24], tmp[7]) - binary.BigEndian.PutUint64(res[8:16], tmp[8]) - binary.BigEndian.PutUint64(res[0:8], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -684,37 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) - binary.BigEndian.PutUint64(res[112:120], tmp[5]) - binary.BigEndian.PutUint64(res[104:112], tmp[6]) - binary.BigEndian.PutUint64(res[96:104], tmp[7]) - binary.BigEndian.PutUint64(res[88:96], tmp[8]) - binary.BigEndian.PutUint64(res[80:88], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - binary.BigEndian.PutUint64(res[32:40], tmp[5]) - binary.BigEndian.PutUint64(res[24:32], tmp[6]) - binary.BigEndian.PutUint64(res[16:24], tmp[7]) - binary.BigEndian.PutUint64(res[8:16], tmp[8]) - binary.BigEndian.PutUint64(res[0:8], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -765,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -786,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -860,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -869,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -878,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -921,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -932,18 +910,7 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - binary.BigEndian.PutUint64(res[32:40], tmp[5]) - binary.BigEndian.PutUint64(res[24:32], tmp[6]) - binary.BigEndian.PutUint64(res[16:24], tmp[7]) - binary.BigEndian.PutUint64(res[8:16], tmp[8]) - binary.BigEndian.PutUint64(res[0:8], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -962,37 +929,12 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[152:160], tmp[0]) - binary.BigEndian.PutUint64(res[144:152], tmp[1]) - binary.BigEndian.PutUint64(res[136:144], tmp[2]) - binary.BigEndian.PutUint64(res[128:136], tmp[3]) - binary.BigEndian.PutUint64(res[120:128], tmp[4]) - binary.BigEndian.PutUint64(res[112:120], tmp[5]) - binary.BigEndian.PutUint64(res[104:112], tmp[6]) - binary.BigEndian.PutUint64(res[96:104], tmp[7]) - binary.BigEndian.PutUint64(res[88:96], tmp[8]) - binary.BigEndian.PutUint64(res[80:88], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[80:80+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[72:80], tmp[0]) - binary.BigEndian.PutUint64(res[64:72], tmp[1]) - binary.BigEndian.PutUint64(res[56:64], tmp[2]) - binary.BigEndian.PutUint64(res[48:56], tmp[3]) - binary.BigEndian.PutUint64(res[40:48], tmp[4]) - binary.BigEndian.PutUint64(res[32:40], tmp[5]) - binary.BigEndian.PutUint64(res[24:32], tmp[6]) - binary.BigEndian.PutUint64(res[16:24], tmp[7]) - binary.BigEndian.PutUint64(res[8:16], tmp[8]) - binary.BigEndian.PutUint64(res[0:8], tmp[9]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -1043,8 +985,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1064,7 +1010,9 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -1138,7 +1086,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1147,7 +1095,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1156,10 +1104,12 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bw6-633/multiexp.go b/ecc/bw6-633/multiexp.go index 200667493..93e279ac6 100644 --- a/ecc/bw6-633/multiexp.go +++ b/ecc/bw6-633/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 6, 8, 12, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,85 +92,126 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] + case 6: + return processChunkG1Jacobian[bucketg1JacExtendedC6] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC12] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC12, bucketG1AffineC12, bitSetC12, pG1AffineC12, ppG1AffineC12, qG1AffineC12, cG1AffineC12] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -327,250 +231,6 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -630,7 +290,7 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 6, 8, 12, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -638,85 +298,126 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG2(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) } - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { switch c { case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC5] + case 6: + return processChunkG2Jacobian[bucketg2JacExtendedC6] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 12: + const batchSize = 200 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC12] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC12, bucketG2AffineC12, bitSetC12, pG2AffineC12, ppG2AffineC12, qG2AffineC12, cG2AffineC12] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } } @@ -736,246 +437,188 @@ func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c +} - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits +} - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG2Affine(p, c, chChunks[:]) + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int } -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + digits := make([]uint16, len(scalars)*int(nbChunks)) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var carry int -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // init with carry if any + digit := carry + carry = 0 - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) + + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bw6-633/multiexp_affine.go b/ecc/bw6-633/multiexp_affine.go new file mode 100644 index 000000000..ae821f8bb --- /dev/null +++ b/ecc/bw6-633/multiexp_affine.go @@ -0,0 +1,549 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6633 + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-633/fp" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC12 [2048]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC12 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC12 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC12 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC12 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC12 | + qG1AffineC16 +} + +// batch size 200 when c = 12 +type cG1AffineC12 [200]fp.Element +type pG1AffineC12 [200]G1Affine +type ppG1AffineC12 [200]*G1Affine +type qG1AffineC12 [200]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC12 [2048]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC12 | + bucketG2AffineC16 +} + +// array of coordinates fp.Element +type cG2Affine interface { + cG2AffineC12 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC12 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC12 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC12 | + qG2AffineC16 +} + +// batch size 200 when c = 12 +type cG2AffineC12 [200]fp.Element +type pG2AffineC12 [200]G2Affine +type ppG2AffineC12 [200]*G2Affine +type qG2AffineC12 [200]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fp.Element +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC6 [32]bool +type bitSetC8 [128]bool +type bitSetC12 [2048]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC4 | + bitSetC5 | + bitSetC6 | + bitSetC8 | + bitSetC12 | + bitSetC16 +} diff --git a/ecc/bw6-633/multiexp_jacobian.go b/ecc/bw6-633/multiexp_jacobian.go new file mode 100644 index 000000000..2d992da14 --- /dev/null +++ b/ecc/bw6-633/multiexp_jacobian.go @@ -0,0 +1,139 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6633 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC6 [32]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC12 [2048]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC6 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC12 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC6 [32]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC12 [2048]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC6 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC12 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bw6-633/multiexp_test.go b/ecc/bw6-633/multiexp_test.go index 725042a72..6014b1b10 100644 --- a/ecc/bw6-633/multiexp_test.go +++ b/ecc/bw6-633/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-633/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 8, 16} + cRange := []uint64{4, 5, 6, 8, 12, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{4, 5, 6, 8, 12, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bw6-633/twistededwards/eddsa/doc.go b/ecc/bw6-633/twistededwards/eddsa/doc.go index f6afacaea..4e8b44bcd 100644 --- a/ecc/bw6-633/twistededwards/eddsa/doc.go +++ b/ecc/bw6-633/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bw6-633's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go b/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go index 93e70b88e..066a7ac20 100644 --- a/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bw6-633/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bw6-633/twistededwards/eddsa/marshal.go b/ecc/bw6-633/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bw6-633/twistededwards/eddsa/marshal.go +++ b/ecc/bw6-633/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bw6-633/twistededwards/point.go b/ecc/bw6-633/twistededwards/point.go index e4055d8d8..45c401d15 100644 --- a/ecc/bw6-633/twistededwards/point.go +++ b/ecc/bw6-633/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bw6-756/bw6-756.go b/ecc/bw6-756/bw6-756.go index ae4713330..0204d7948 100644 --- a/ecc/bw6-756/bw6-756.go +++ b/ecc/bw6-756/bw6-756.go @@ -1,22 +1,28 @@ // Package bw6756 efficient elliptic curve, pairing and hash to curve implementation for bw6-756. // // bw6-756: A Brezing--Weng curve (2-chain with bls12-378) -// embedding degree k=6 -// seed x₀=11045256207009841153. -// 𝔽p: p=366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 -// 𝔽r: r=605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// (E/𝔽p): Y²=X³+1 -// (Eₜ/𝔽p): Y² = X³+33 (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// +// embedding degree k=6 +// seed x₀=11045256207009841153. +// 𝔽p: p=366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 +// 𝔽r: r=605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// (E/𝔽p): Y²=X³+1 +// (Eₜ/𝔽p): Y² = X³+33 (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// // Extension fields tower: -// 𝔽p³[u] = 𝔽p/u³-33 -// 𝔽p⁶[v] = 𝔽p²/v²-u +// +// 𝔽p³[u] = 𝔽p/u³-33 +// 𝔽p⁶[v] = 𝔽p²/v²-u +// // optimal Ate loops: -// x₀+1, x₀²-x₀-1 +// +// x₀+1, x₀²-x₀-1 +// // Security: estimated 126-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 378 bits and p⁶ is 4536 bits) // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bw6756 diff --git a/ecc/bw6-756/fp/doc.go b/ecc/bw6-756/fp/doc.go index 8421a7225..a94273407 100644 --- a/ecc/bw6-756/fp/doc.go +++ b/ecc/bw6-756/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [12]uint64 // -// Usage +// type Element [12]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 -// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 +// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 +// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bw6-756/fp/element.go b/ecc/bw6-756/fp/element.go index 4ef865aed..1be54119d 100644 --- a/ecc/bw6-756/fp/element.go +++ b/ecc/bw6-756/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 12 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 -// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 +// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 +// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [12]uint64 const ( - Limbs = 12 // number of 64 bits words needed to represent a Element - Bits = 756 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 12 // number of 64 bits words needed to represent a Element + Bits = 756 // number of bits needed to represent a Element + Bytes = 96 // number of bytes needed to represent a Element ) // Field modulus q @@ -84,8 +84,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 -// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 +// q[base10] = 366325390957376286590726555727219947825377821289246188278797409783441745356050456327989347160777465284190855125642086860525706497928518803244008749360363712553766506755227344593404398783886857865261088226271336335268413437902849 +// q[base16] = 0xf76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -94,12 +94,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 18446744073709551615 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("f76adbb5bb98ae2ac127e1e3568cf5c978cd2fac2ce89fbf23221455163a6ccc6ae73c42a46d9eb02c812ea04faaa0a7eb1cb3d06e646e292cd15edb646a54302aa3c258de7ded0b685e868524ec033c7e63f868400000000000000000001", 16) } @@ -107,8 +101,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -119,7 +114,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -157,14 +152,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -290,15 +286,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -310,15 +304,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[11] > _x[11] { return 1 } else if _z[11] < _x[11] { @@ -389,8 +380,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 1, 0) @@ -513,67 +503,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -595,7 +527,7 @@ func (z *Element) Add(x, y *Element) *Element { z[10], carry = bits.Add64(x[10], y[10], carry) z[11], _ = bits.Add64(x[11], y[11], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -631,7 +563,7 @@ func (z *Element) Double(x *Element) *Element { z[10], carry = bits.Add64(x[10], x[10], carry) z[11], _ = bits.Add64(x[11], x[11], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -724,361 +656,531 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [12]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd1(v, y[6], c[1]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd1(v, y[7], c[1]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd1(v, y[8], c[1]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd1(v, y[9], c[1]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd1(v, y[10], c[1]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd1(v, y[11], c[1]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 6 - v := x[6] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 7 - v := x[7] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 8 - v := x[8] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 9 - v := x[9] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 10 - v := x[10] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 11 - v := x[11] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], z[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], z[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], z[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], z[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], z[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], z[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - z[11], z[10] = madd3(m, q11, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [13]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + C, t[6] = madd1(y[0], x[6], C) + C, t[7] = madd1(y[0], x[7], C) + C, t[8] = madd1(y[0], x[8], C) + C, t[9] = madd1(y[0], x[9], C) + C, t[10] = madd1(y[0], x[10], C) + C, t[11] = madd1(y[0], x[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + C, t[6] = madd2(y[1], x[6], t[6], C) + C, t[7] = madd2(y[1], x[7], t[7], C) + C, t[8] = madd2(y[1], x[8], t[8], C) + C, t[9] = madd2(y[1], x[9], t[9], C) + C, t[10] = madd2(y[1], x[10], t[10], C) + C, t[11] = madd2(y[1], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + C, t[6] = madd2(y[2], x[6], t[6], C) + C, t[7] = madd2(y[2], x[7], t[7], C) + C, t[8] = madd2(y[2], x[8], t[8], C) + C, t[9] = madd2(y[2], x[9], t[9], C) + C, t[10] = madd2(y[2], x[10], t[10], C) + C, t[11] = madd2(y[2], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + C, t[6] = madd2(y[3], x[6], t[6], C) + C, t[7] = madd2(y[3], x[7], t[7], C) + C, t[8] = madd2(y[3], x[8], t[8], C) + C, t[9] = madd2(y[3], x[9], t[9], C) + C, t[10] = madd2(y[3], x[10], t[10], C) + C, t[11] = madd2(y[3], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + C, t[6] = madd2(y[4], x[6], t[6], C) + C, t[7] = madd2(y[4], x[7], t[7], C) + C, t[8] = madd2(y[4], x[8], t[8], C) + C, t[9] = madd2(y[4], x[9], t[9], C) + C, t[10] = madd2(y[4], x[10], t[10], C) + C, t[11] = madd2(y[4], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + C, t[6] = madd2(y[5], x[6], t[6], C) + C, t[7] = madd2(y[5], x[7], t[7], C) + C, t[8] = madd2(y[5], x[8], t[8], C) + C, t[9] = madd2(y[5], x[9], t[9], C) + C, t[10] = madd2(y[5], x[10], t[10], C) + C, t[11] = madd2(y[5], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[6], x[0], t[0]) + C, t[1] = madd2(y[6], x[1], t[1], C) + C, t[2] = madd2(y[6], x[2], t[2], C) + C, t[3] = madd2(y[6], x[3], t[3], C) + C, t[4] = madd2(y[6], x[4], t[4], C) + C, t[5] = madd2(y[6], x[5], t[5], C) + C, t[6] = madd2(y[6], x[6], t[6], C) + C, t[7] = madd2(y[6], x[7], t[7], C) + C, t[8] = madd2(y[6], x[8], t[8], C) + C, t[9] = madd2(y[6], x[9], t[9], C) + C, t[10] = madd2(y[6], x[10], t[10], C) + C, t[11] = madd2(y[6], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[7], x[0], t[0]) + C, t[1] = madd2(y[7], x[1], t[1], C) + C, t[2] = madd2(y[7], x[2], t[2], C) + C, t[3] = madd2(y[7], x[3], t[3], C) + C, t[4] = madd2(y[7], x[4], t[4], C) + C, t[5] = madd2(y[7], x[5], t[5], C) + C, t[6] = madd2(y[7], x[6], t[6], C) + C, t[7] = madd2(y[7], x[7], t[7], C) + C, t[8] = madd2(y[7], x[8], t[8], C) + C, t[9] = madd2(y[7], x[9], t[9], C) + C, t[10] = madd2(y[7], x[10], t[10], C) + C, t[11] = madd2(y[7], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[8], x[0], t[0]) + C, t[1] = madd2(y[8], x[1], t[1], C) + C, t[2] = madd2(y[8], x[2], t[2], C) + C, t[3] = madd2(y[8], x[3], t[3], C) + C, t[4] = madd2(y[8], x[4], t[4], C) + C, t[5] = madd2(y[8], x[5], t[5], C) + C, t[6] = madd2(y[8], x[6], t[6], C) + C, t[7] = madd2(y[8], x[7], t[7], C) + C, t[8] = madd2(y[8], x[8], t[8], C) + C, t[9] = madd2(y[8], x[9], t[9], C) + C, t[10] = madd2(y[8], x[10], t[10], C) + C, t[11] = madd2(y[8], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[9], x[0], t[0]) + C, t[1] = madd2(y[9], x[1], t[1], C) + C, t[2] = madd2(y[9], x[2], t[2], C) + C, t[3] = madd2(y[9], x[3], t[3], C) + C, t[4] = madd2(y[9], x[4], t[4], C) + C, t[5] = madd2(y[9], x[5], t[5], C) + C, t[6] = madd2(y[9], x[6], t[6], C) + C, t[7] = madd2(y[9], x[7], t[7], C) + C, t[8] = madd2(y[9], x[8], t[8], C) + C, t[9] = madd2(y[9], x[9], t[9], C) + C, t[10] = madd2(y[9], x[10], t[10], C) + C, t[11] = madd2(y[9], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[10], x[0], t[0]) + C, t[1] = madd2(y[10], x[1], t[1], C) + C, t[2] = madd2(y[10], x[2], t[2], C) + C, t[3] = madd2(y[10], x[3], t[3], C) + C, t[4] = madd2(y[10], x[4], t[4], C) + C, t[5] = madd2(y[10], x[5], t[5], C) + C, t[6] = madd2(y[10], x[6], t[6], C) + C, t[7] = madd2(y[10], x[7], t[7], C) + C, t[8] = madd2(y[10], x[8], t[8], C) + C, t[9] = madd2(y[10], x[9], t[9], C) + C, t[10] = madd2(y[10], x[10], t[10], C) + C, t[11] = madd2(y[10], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[11], x[0], t[0]) + C, t[1] = madd2(y[11], x[1], t[1], C) + C, t[2] = madd2(y[11], x[2], t[2], C) + C, t[3] = madd2(y[11], x[3], t[3], C) + C, t[4] = madd2(y[11], x[4], t[4], C) + C, t[5] = madd2(y[11], x[5], t[5], C) + C, t[6] = madd2(y[11], x[6], t[6], C) + C, t[7] = madd2(y[11], x[7], t[7], C) + C, t[8] = madd2(y[11], x[8], t[8], C) + C, t[9] = madd2(y[11], x[9], t[9], C) + C, t[10] = madd2(y[11], x[10], t[10], C) + C, t[11] = madd2(y[11], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + + if t[12] != 0 { + // we need to reduce, we have a result on 13 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], b = bits.Sub64(t[5], q5, b) + z[6], b = bits.Sub64(t[6], q6, b) + z[7], b = bits.Sub64(t[7], q7, b) + z[8], b = bits.Sub64(t[8], q8, b) + z[9], b = bits.Sub64(t[9], q9, b) + z[10], b = bits.Sub64(t[10], q10, b) + z[11], _ = bits.Sub64(t[11], q11, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + z[6] = t[6] + z[7] = t[7] + z[8] = t[8] + z[9] = t[9] + z[10] = t[10] + z[11] = t[11] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1094,7 +1196,6 @@ func _mulGeneric(z, x, y *Element) { z[10], b = bits.Sub64(z[10], q10, b) z[11], _ = bits.Sub64(z[11], q11, b) } - } func _fromMontGeneric(z *Element) { @@ -1306,7 +1407,7 @@ func _fromMontGeneric(z *Element) { z[11] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1326,7 +1427,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1422,6 +1523,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -1436,8 +1566,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -1471,23 +1601,37 @@ var rSquare = Element{ 4166778949326216, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[88:96], z[0]) + binary.BigEndian.PutUint64(b[80:88], z[1]) + binary.BigEndian.PutUint64(b[72:80], z[2]) + binary.BigEndian.PutUint64(b[64:72], z[3]) + binary.BigEndian.PutUint64(b[56:64], z[4]) + binary.BigEndian.PutUint64(b[48:56], z[5]) + binary.BigEndian.PutUint64(b[40:48], z[6]) + binary.BigEndian.PutUint64(b[32:40], z[7]) + binary.BigEndian.PutUint64(b[24:32], z[8]) + binary.BigEndian.PutUint64(b[16:24], z[9]) + binary.BigEndian.PutUint64(b[8:16], z[10]) + binary.BigEndian.PutUint64(b[0:8], z[11]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -1506,63 +1650,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[88:96], z[0]) - binary.BigEndian.PutUint64(b[80:88], z[1]) - binary.BigEndian.PutUint64(b[72:80], z[2]) - binary.BigEndian.PutUint64(b[64:72], z[3]) - binary.BigEndian.PutUint64(b[56:64], z[4]) - binary.BigEndian.PutUint64(b[48:56], z[5]) - binary.BigEndian.PutUint64(b[40:48], z[6]) - binary.BigEndian.PutUint64(b[32:40], z[7]) - binary.BigEndian.PutUint64(b[24:32], z[8]) - binary.BigEndian.PutUint64(b[16:24], z[9]) - binary.BigEndian.PutUint64(b[8:16], z[10]) - binary.BigEndian.PutUint64(b[0:8], z[11]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[88:96], _z[0]) - binary.BigEndian.PutUint64(res[80:88], _z[1]) - binary.BigEndian.PutUint64(res[72:80], _z[2]) - binary.BigEndian.PutUint64(res[64:72], _z[3]) - binary.BigEndian.PutUint64(res[56:64], _z[4]) - binary.BigEndian.PutUint64(res[48:56], _z[5]) - binary.BigEndian.PutUint64(res[40:48], _z[6]) - binary.BigEndian.PutUint64(res[32:40], _z[7]) - binary.BigEndian.PutUint64(res[24:32], _z[8]) - binary.BigEndian.PutUint64(res[16:24], _z[9]) - binary.BigEndian.PutUint64(res[8:16], _z[10]) - binary.BigEndian.PutUint64(res[0:8], _z[11]) +// Bits provides access to z by returning its value as a little-endian [12]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [12]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1575,19 +1705,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 96-byte integer. +// If e is not a 96-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1605,17 +1760,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1637,20 +1791,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1659,7 +1813,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1668,7 +1822,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1708,7 +1862,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1717,10 +1871,111 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 96-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[88:96]) + z[1] = binary.BigEndian.Uint64((*b)[80:88]) + z[2] = binary.BigEndian.Uint64((*b)[72:80]) + z[3] = binary.BigEndian.Uint64((*b)[64:72]) + z[4] = binary.BigEndian.Uint64((*b)[56:64]) + z[5] = binary.BigEndian.Uint64((*b)[48:56]) + z[6] = binary.BigEndian.Uint64((*b)[40:48]) + z[7] = binary.BigEndian.Uint64((*b)[32:40]) + z[8] = binary.BigEndian.Uint64((*b)[24:32]) + z[9] = binary.BigEndian.Uint64((*b)[16:24]) + z[10] = binary.BigEndian.Uint64((*b)[8:16]) + z[11] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[88:96], e[0]) + binary.BigEndian.PutUint64((*b)[80:88], e[1]) + binary.BigEndian.PutUint64((*b)[72:80], e[2]) + binary.BigEndian.PutUint64((*b)[64:72], e[3]) + binary.BigEndian.PutUint64((*b)[56:64], e[4]) + binary.BigEndian.PutUint64((*b)[48:56], e[5]) + binary.BigEndian.PutUint64((*b)[40:48], e[6]) + binary.BigEndian.PutUint64((*b)[32:40], e[7]) + binary.BigEndian.PutUint64((*b)[24:32], e[8]) + binary.BigEndian.PutUint64((*b)[16:24], e[9]) + binary.BigEndian.PutUint64((*b)[8:16], e[10]) + binary.BigEndian.PutUint64((*b)[0:8], e[11]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + z[6] = binary.LittleEndian.Uint64((*b)[48:56]) + z[7] = binary.LittleEndian.Uint64((*b)[56:64]) + z[8] = binary.LittleEndian.Uint64((*b)[64:72]) + z[9] = binary.LittleEndian.Uint64((*b)[72:80]) + z[10] = binary.LittleEndian.Uint64((*b)[80:88]) + z[11] = binary.LittleEndian.Uint64((*b)[88:96]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) + binary.LittleEndian.PutUint64((*b)[48:56], e[6]) + binary.LittleEndian.PutUint64((*b)[56:64], e[7]) + binary.LittleEndian.PutUint64((*b)[64:72], e[8]) + binary.LittleEndian.PutUint64((*b)[72:80], e[9]) + binary.LittleEndian.PutUint64((*b)[80:88], e[10]) + binary.LittleEndian.PutUint64((*b)[88:96], e[11]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1753,7 +2008,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1774,7 +2029,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(82) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1996,7 +2251,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -2019,17 +2274,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -2312,7 +2578,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[11], z[10] = madd2(m, q11, t[i+11], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-756/fp/element_mul_adx_amd64.s b/ecc/bw6-756/fp/element_mul_adx_amd64.s deleted file mode 100644 index 1647f7e43..000000000 --- a/ecc/bw6-756/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,2738 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $1 -DATA q<>+8(SB)/8, $0x33c7e63f86840000 -DATA q<>+16(SB)/8, $0xd0b685e868524ec0 -DATA q<>+24(SB)/8, $0x4302aa3c258de7de -DATA q<>+32(SB)/8, $0xe292cd15edb646a5 -DATA q<>+40(SB)/8, $0x0a7eb1cb3d06e646 -DATA q<>+48(SB)/8, $0xeb02c812ea04faaa -DATA q<>+56(SB)/8, $0xccc6ae73c42a46d9 -DATA q<>+64(SB)/8, $0xfbf23221455163a6 -DATA q<>+72(SB)/8, $0x5c978cd2fac2ce89 -DATA q<>+80(SB)/8, $0xe2ac127e1e3568cf -DATA q<>+88(SB)/8, $0x000f76adbb5bb98a -GLOBL q<>(SB), (RODATA+NOPTR), $96 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0xffffffffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - MOVQ ra10, rb10; \ - SBBQ q<>+80(SB), ra10; \ - MOVQ ra11, rb11; \ - SBBQ q<>+88(SB), ra11; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - CMOVQCS rb10, ra10; \ - CMOVQCS rb11, ra11; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $96-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - MOVQ x+8(FP), AX - - // x[0] -> s0-8(SP) - // x[1] -> s1-16(SP) - // x[2] -> s2-24(SP) - // x[3] -> s3-32(SP) - // x[4] -> s4-40(SP) - // x[5] -> s5-48(SP) - // x[6] -> s6-56(SP) - // x[7] -> s7-64(SP) - // x[8] -> s8-72(SP) - // x[9] -> s9-80(SP) - // x[10] -> s10-88(SP) - // x[11] -> s11-96(SP) - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - MOVQ 40(AX), DI - MOVQ 48(AX), R8 - MOVQ 56(AX), R9 - MOVQ 64(AX), R10 - MOVQ 72(AX), R11 - MOVQ 80(AX), R12 - MOVQ 88(AX), R13 - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ DI, s5-48(SP) - MOVQ R8, s6-56(SP) - MOVQ R9, s7-64(SP) - MOVQ R10, s8-72(SP) - MOVQ R11, s9-80(SP) - MOVQ R12, s10-88(SP) - MOVQ R13, s11-96(SP) - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // t[6] -> R8 - // t[7] -> R9 - // t[8] -> R10 - // t[9] -> R11 - // t[10] -> R12 - // t[11] -> R13 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 0(AX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ s0-8(SP), R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ s1-16(SP), AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ s2-24(SP), AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ s3-32(SP), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ s4-40(SP), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ s5-48(SP), AX, R8 - ADOXQ AX, DI - - // (A,t[6]) := x[6]*y[0] + A - MULXQ s6-56(SP), AX, R9 - ADOXQ AX, R8 - - // (A,t[7]) := x[7]*y[0] + A - MULXQ s7-64(SP), AX, R10 - ADOXQ AX, R9 - - // (A,t[8]) := x[8]*y[0] + A - MULXQ s8-72(SP), AX, R11 - ADOXQ AX, R10 - - // (A,t[9]) := x[9]*y[0] + A - MULXQ s9-80(SP), AX, R12 - ADOXQ AX, R11 - - // (A,t[10]) := x[10]*y[0] + A - MULXQ s10-88(SP), AX, R13 - ADOXQ AX, R12 - - // (A,t[11]) := x[11]*y[0] + A - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 8(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[1] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[1] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[1] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[1] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[1] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[1] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 16(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[2] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[2] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[2] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[2] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[2] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[2] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 24(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[3] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[3] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[3] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[3] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[3] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[3] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 32(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[4] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[4] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[4] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[4] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[4] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[4] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 40(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[5] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[5] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[5] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[5] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[5] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[5] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 48(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[6] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[6] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[6] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[6] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[6] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[6] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[6] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[6] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[6] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[6] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[6] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[6] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 56(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[7] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[7] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[7] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[7] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[7] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[7] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[7] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[7] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[7] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[7] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[7] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[7] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 64(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[8] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[8] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[8] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[8] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[8] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[8] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[8] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[8] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[8] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[8] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[8] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[8] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 72(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[9] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[9] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[9] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[9] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[9] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[9] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[9] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[9] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[9] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[9] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[9] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[9] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 80(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[10] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[10] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[10] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[10] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[10] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[10] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[10] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[10] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[10] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[10] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[10] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[10] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 88(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[11] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[11] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[11] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[11] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[11] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[11] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[11] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[11] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[11] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[11] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[11] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[11] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET - -TEXT ·fromMont(SB), $96-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - MOVQ 64(DX), R10 - MOVQ 72(DX), R11 - MOVQ 80(DX), R12 - MOVQ 88(DX), R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET diff --git a/ecc/bw6-756/fp/element_mul_amd64.s b/ecc/bw6-756/fp/element_mul_amd64.s index 5fb0dd4a7..b01df63c2 100644 --- a/ecc/bw6-756/fp/element_mul_amd64.s +++ b/ecc/bw6-756/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-756/fp/element_ops_amd64.go b/ecc/bw6-756/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bw6-756/fp/element_ops_amd64.go +++ b/ecc/bw6-756/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-756/fp/element_ops_amd64.s b/ecc/bw6-756/fp/element_ops_amd64.s index 61f80c041..57ceebbe4 100644 --- a/ecc/bw6-756/fp/element_ops_amd64.s +++ b/ecc/bw6-756/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-756/fp/element_ops_noasm.go b/ecc/bw6-756/fp/element_ops_noasm.go deleted file mode 100644 index 194638ba4..000000000 --- a/ecc/bw6-756/fp/element_ops_noasm.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 18446744073709496521, - 18279598932724285439, - 16020160894313802039, - 6734264120679796183, - 592370298347718455, - 6302256704987790972, - 3197310980453914279, - 2651858637075104463, - 8565083029697102127, - 15288469570225946050, - 14519635472382186671, - 448955992735434, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-756/fp/element_ops_purego.go b/ecc/bw6-756/fp/element_ops_purego.go new file mode 100644 index 000000000..2b12c05a0 --- /dev/null +++ b/ecc/bw6-756/fp/element_ops_purego.go @@ -0,0 +1,2227 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 18446744073709496521, + 18279598932724285439, + 16020160894313802039, + 6734264120679796183, + 592370298347718455, + 6302256704987790972, + 3197310980453914279, + 2651858637075104463, + 8565083029697102127, + 15288469570225946050, + 14519635472382186671, + 448955992735434, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + u6, t6 = bits.Mul64(v, y[6]) + u7, t7 = bits.Mul64(v, y[7]) + u8, t8 = bits.Mul64(v, y[8]) + u9, t9 = bits.Mul64(v, y[9]) + u10, t10 = bits.Mul64(v, y[10]) + u11, t11 = bits.Mul64(v, y[11]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[10] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[11] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + z[10] = t10 + z[11] = t11 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], b = bits.Sub64(z[9], q9, b) + z[10], b = bits.Sub64(z[10], q10, b) + z[11], _ = bits.Sub64(z[11], q11, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + u6, t6 = bits.Mul64(v, x[6]) + u7, t7 = bits.Mul64(v, x[7]) + u8, t8 = bits.Mul64(v, x[8]) + u9, t9 = bits.Mul64(v, x[9]) + u10, t10 = bits.Mul64(v, x[10]) + u11, t11 = bits.Mul64(v, x[11]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[10] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[11] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + z[10] = t10 + z[11] = t11 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], b = bits.Sub64(z[9], q9, b) + z[10], b = bits.Sub64(z[10], q10, b) + z[11], _ = bits.Sub64(z[11], q11, b) + } + return z +} diff --git a/ecc/bw6-756/fp/element_test.go b/ecc/bw6-756/fp/element_test.go index 61b5a8859..7dbbb6d44 100644 --- a/ecc/bw6-756/fp/element_test.go +++ b/ecc/bw6-756/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -654,7 +647,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -769,7 +762,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -782,13 +775,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -817,17 +810,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -878,7 +871,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -891,13 +884,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -926,17 +919,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -987,7 +980,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1000,7 +993,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1014,7 +1007,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1054,11 +1047,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1071,7 +1064,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1123,7 +1116,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1136,14 +1129,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1172,18 +1165,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1234,7 +1227,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1247,13 +1240,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1282,17 +1275,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1337,7 +1330,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1358,14 +1351,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1409,7 +1402,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1430,14 +1423,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1481,7 +1474,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1502,14 +1495,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1553,7 +1546,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1574,14 +1567,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1625,7 +1618,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1646,14 +1639,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2041,7 +2034,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2174,17 +2167,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2279,7 +2272,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2365,7 +2358,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2415,7 +2408,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord10, inversionCorrectionFactorWord11, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2457,7 +2450,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2618,7 +2611,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2831,11 +2824,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2872,7 +2865,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2890,7 +2883,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2901,7 +2894,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-756/fr/doc.go b/ecc/bw6-756/fr/doc.go index 50fa83dec..38411d5d4 100644 --- a/ecc/bw6-756/fr/doc.go +++ b/ecc/bw6-756/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [6]uint64 // -// Usage +// type Element [6]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bw6-756/fr/element.go b/ecc/bw6-756/fr/element.go index 6beb3e9c0..1f601ad18 100644 --- a/ecc/bw6-756/fr/element.go +++ b/ecc/bw6-756/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 6 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [6]uint64 const ( - Limbs = 6 // number of 64 bits words needed to represent a Element - Bits = 378 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 6 // number of 64 bits words needed to represent a Element + Bits = 378 // number of bits needed to represent a Element + Bytes = 48 // number of bytes needed to represent a Element ) // Field modulus q @@ -72,8 +72,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 -// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 +// q[base10] = 605248206075306171733248481581800960739847691770924913753520744034740935903401304776283802348837311170974282940417 +// q[base16] = 0x3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -82,12 +82,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 11045256207009841151 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("3eeb0416684d19053cb5d240ed107a284059eb647102326980dc360d0a49d7fce97f76a822c00009948a20000000001", 16) } @@ -95,8 +89,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -107,7 +102,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -139,14 +134,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -260,15 +256,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -280,15 +274,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[5] > _x[5] { return 1 } else if _z[5] < _x[5] { @@ -329,8 +320,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 5522628103504920577, 0) @@ -429,67 +419,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -505,7 +437,7 @@ func (z *Element) Add(x, y *Element) *Element { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -529,7 +461,7 @@ func (z *Element) Double(x *Element) *Element { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -592,115 +524,219 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [6]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - z[5], z[4] = madd3(m, q5, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [7]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + + if t[6] != 0 { + // we need to reduce, we have a result on 7 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], _ = bits.Sub64(t[5], q5, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -710,7 +746,6 @@ func _mulGeneric(z, x, y *Element) { z[4], b = bits.Sub64(z[4], q4, b) z[5], _ = bits.Sub64(z[5], q5, b) } - } func _fromMontGeneric(z *Element) { @@ -784,7 +819,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -798,7 +833,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -870,6 +905,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -884,8 +948,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -913,23 +977,31 @@ var rSquare = Element{ 51529254522778566, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -948,51 +1020,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[40:48], z[0]) - binary.BigEndian.PutUint64(b[32:40], z[1]) - binary.BigEndian.PutUint64(b[24:32], z[2]) - binary.BigEndian.PutUint64(b[16:24], z[3]) - binary.BigEndian.PutUint64(b[8:16], z[4]) - binary.BigEndian.PutUint64(b[0:8], z[5]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[40:48], _z[0]) - binary.BigEndian.PutUint64(res[32:40], _z[1]) - binary.BigEndian.PutUint64(res[24:32], _z[2]) - binary.BigEndian.PutUint64(res[16:24], _z[3]) - binary.BigEndian.PutUint64(res[8:16], _z[4]) - binary.BigEndian.PutUint64(res[0:8], _z[5]) +// Bits provides access to z by returning its value as a little-endian [6]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [6]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1005,19 +1075,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 48-byte integer. +// If e is not a 48-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1035,17 +1130,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1067,20 +1161,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1089,7 +1183,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1098,7 +1192,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1138,7 +1232,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1147,10 +1241,87 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 48-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[40:48]) + z[1] = binary.BigEndian.Uint64((*b)[32:40]) + z[2] = binary.BigEndian.Uint64((*b)[24:32]) + z[3] = binary.BigEndian.Uint64((*b)[16:24]) + z[4] = binary.BigEndian.Uint64((*b)[8:16]) + z[5] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[40:48], e[0]) + binary.BigEndian.PutUint64((*b)[32:40], e[1]) + binary.BigEndian.PutUint64((*b)[24:32], e[2]) + binary.BigEndian.PutUint64((*b)[16:24], e[3]) + binary.BigEndian.PutUint64((*b)[8:16], e[4]) + binary.BigEndian.PutUint64((*b)[0:8], e[5]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1183,7 +1354,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1198,7 +1369,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(41) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1396,7 +1567,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1413,17 +1584,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1556,7 +1738,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[5], z[4] = madd2(m, q5, t[i+5], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-756/fr/element_mul_adx_amd64.s b/ecc/bw6-756/fr/element_mul_adx_amd64.s deleted file mode 100644 index 58909b6ec..000000000 --- a/ecc/bw6-756/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,835 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x9948a20000000001 -DATA q<>+8(SB)/8, $0xce97f76a822c0000 -DATA q<>+16(SB)/8, $0x980dc360d0a49d7f -DATA q<>+24(SB)/8, $0x84059eb647102326 -DATA q<>+32(SB)/8, $0x53cb5d240ed107a2 -DATA q<>+40(SB)/8, $0x03eeb0416684d190 -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x9948a1ffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET diff --git a/ecc/bw6-756/fr/element_mul_amd64.s b/ecc/bw6-756/fr/element_mul_amd64.s index 3afd58112..39ededda7 100644 --- a/ecc/bw6-756/fr/element_mul_amd64.s +++ b/ecc/bw6-756/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-756/fr/element_ops_amd64.go b/ecc/bw6-756/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bw6-756/fr/element_ops_amd64.go +++ b/ecc/bw6-756/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-756/fr/element_ops_amd64.s b/ecc/bw6-756/fr/element_ops_amd64.s index fa881ff9c..9440e0ccb 100644 --- a/ecc/bw6-756/fr/element_ops_amd64.s +++ b/ecc/bw6-756/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-756/fr/element_ops_noasm.go b/ecc/bw6-756/fr/element_ops_noasm.go deleted file mode 100644 index f6c33e63b..000000000 --- a/ecc/bw6-756/fr/element_ops_noasm.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 8212494240417053874, - 5029498262967025157, - 9404736542133420963, - 13073247822498485877, - 1581382318314538223, - 87125160541517067, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-756/fr/element_ops_purego.go b/ecc/bw6-756/fr/element_ops_purego.go new file mode 100644 index 000000000..4b243fca5 --- /dev/null +++ b/ecc/bw6-756/fr/element_ops_purego.go @@ -0,0 +1,745 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 8212494240417053874, + 5029498262967025157, + 9404736542133420963, + 13073247822498485877, + 1581382318314538223, + 87125160541517067, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} diff --git a/ecc/bw6-756/fr/element_test.go b/ecc/bw6-756/fr/element_test.go index e819d5704..d38e7128a 100644 --- a/ecc/bw6-756/fr/element_test.go +++ b/ecc/bw6-756/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -642,7 +635,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -757,7 +750,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -770,13 +763,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -805,17 +798,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -866,7 +859,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -879,13 +872,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -914,17 +907,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -975,7 +968,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -988,7 +981,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1002,7 +995,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1042,11 +1035,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1059,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1111,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1124,14 +1117,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1160,18 +1153,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1222,7 +1215,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,13 +1228,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1270,17 +1263,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1325,7 +1318,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1346,14 +1339,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1397,7 +1390,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1418,14 +1411,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1469,7 +1462,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1490,14 +1483,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1541,7 +1534,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1562,14 +1555,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1613,7 +1606,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1634,14 +1627,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2029,7 +2022,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2162,17 +2155,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2255,7 +2248,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2323,7 +2316,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2367,7 +2360,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord4, inversionCorrectionFactorWord5, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2409,7 +2402,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2558,7 +2551,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2771,11 +2764,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2812,7 +2805,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2830,7 +2823,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2841,7 +2834,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-756/fr/fri/fri.go b/ecc/bw6-756/fr/fri/fri.go index 25efd1527..a349ec816 100644 --- a/ecc/bw6-756/fr/fri/fri.go +++ b/ecc/bw6-756/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bw6-756/fr/gkr/gkr.go b/ecc/bw6-756/fr/gkr/gkr.go new file mode 100644 index 000000000..1925f630d --- /dev/null +++ b/ecc/bw6-756/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bw6-756/fr/gkr/gkr_test.go b/ecc/bw6-756/fr/gkr/gkr_test.go new file mode 100644 index 000000000..de89bcaad --- /dev/null +++ b/ecc/bw6-756/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bw6-756/fr/kzg/kzg.go b/ecc/bw6-756/fr/kzg/kzg.go index 7796c16b5..fa3e1f420 100644 --- a/ecc/bw6-756/fr/kzg/kzg.go +++ b/ecc/bw6-756/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bw6756.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bw6756.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bw6-756/fr/mimc/decompose.go b/ecc/bw6-756/fr/mimc/decompose.go new file mode 100644 index 000000000..c5bb70239 --- /dev/null +++ b/ecc/bw6-756/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bw6-756/fr/mimc/decompose_test.go b/ecc/bw6-756/fr/mimc/decompose_test.go new file mode 100644 index 000000000..04d86a58c --- /dev/null +++ b/ecc/bw6-756/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bw6-756/fr/mimc/mimc.go b/ecc/bw6-756/fr/mimc/mimc.go index 26f6bb9e8..9218a8e35 100644 --- a/ecc/bw6-756/fr/mimc/mimc.go +++ b/ecc/bw6-756/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bw6-756/fr/pedersen/pedersen.go b/ecc/bw6-756/fr/pedersen/pedersen.go new file mode 100644 index 000000000..a670f6ac5 --- /dev/null +++ b/ecc/bw6-756/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-756" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bw6756.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bw6756.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bw6756.G1Affine + basisExpSigma []bw6756.G1Affine +} + +func randomOnG2() (bw6756.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6756.G2Affine{}, err + } + return bw6756.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bw6756.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bw6756.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bw6756.G1Affine, knowledgeProof bw6756.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bw6756.G1Affine, knowledgeProof bw6756.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bw6756.Pair([]bw6756.G1Affine{commitment, knowledgeProof}, []bw6756.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bw6-756/fr/pedersen/pedersen_test.go b/ecc/bw6-756/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..a4b5387c5 --- /dev/null +++ b/ecc/bw6-756/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-756" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bw6756.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6756.G1Affine{}, err + } + return bw6756.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bw6756.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bw6756.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bw6-756/fr/plookup/vector.go b/ecc/bw6-756/fr/plookup/vector.go index 3bee7af96..b992efa13 100644 --- a/ecc/bw6-756/fr/plookup/vector.go +++ b/ecc/bw6-756/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bw6-756/fr/polynomial/multilin.go b/ecc/bw6-756/fr/polynomial/multilin.go index 26ba3f26f..551f10c40 100644 --- a/ecc/bw6-756/fr/polynomial/multilin.go +++ b/ecc/bw6-756/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bw6-756/fr/polynomial/polynomial_test.go b/ecc/bw6-756/fr/polynomial/polynomial_test.go index 9a1298763..414f3a307 100644 --- a/ecc/bw6-756/fr/polynomial/polynomial_test.go +++ b/ecc/bw6-756/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bw6-756/fr/polynomial/pool.go b/ecc/bw6-756/fr/polynomial/pool.go index 3d3f5155f..98e8f3e02 100644 --- a/ecc/bw6-756/fr/polynomial/pool.go +++ b/ecc/bw6-756/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bw6-756/fr/sumcheck/sumcheck.go b/ecc/bw6-756/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..c27e7da8a --- /dev/null +++ b/ecc/bw6-756/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bw6-756/fr/sumcheck/sumcheck_test.go b/ecc/bw6-756/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..99801b312 --- /dev/null +++ b/ecc/bw6-756/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bw6-756/fr/test_vector_utils/test_vector_utils.go b/ecc/bw6-756/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..cf5544679 --- /dev/null +++ b/ecc/bw6-756/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-756/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bw6-756/g1.go b/ecc/bw6-756/g1.go index 899113c8d..9ad403028 100644 --- a/ecc/bw6-756/g1.go +++ b/ecc/bw6-756/g1.go @@ -17,13 +17,12 @@ package bw6756 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -55,6 +54,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -378,6 +384,7 @@ func (p *G1Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG1Affine, 1] is the kernel // of (u,v)->u+lambdaG1Affinev mod r. Expressing r, lambdaG1Affine as // polynomials in x, a short vector of this Zmodule is @@ -481,8 +488,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -521,6 +528,7 @@ func (p *G1Affine) ClearCofactor(a *G1Affine) *G1Affine { // ClearCofactor maps a point in E(Fp) to E(Fp)[r] func (p *G1Jac) ClearCofactor(a *G1Jac) *G1Jac { + var L0, L1, uP, u2P, u3P, tmp G1Jac uP.ScalarMultiplication(a, &xGen) @@ -612,15 +620,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -632,11 +640,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -661,6 +665,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -978,95 +984,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -1080,3 +1063,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-756/g1_test.go b/ecc/bw6-756/g1_test.go index 81ecf8155..4ab6a4cee 100644 --- a/ecc/bw6-756/g1_test.go +++ b/ecc/bw6-756/g1_test.go @@ -19,6 +19,7 @@ package bw6756 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-756/g2.go b/ecc/bw6-756/g2.go index 3abfa71d5..641b798cd 100644 --- a/ecc/bw6-756/g2.go +++ b/ecc/bw6-756/g2.go @@ -17,13 +17,12 @@ package bw6756 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fp.Element } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -365,6 +371,7 @@ func (p *G2Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG2Affine, 1] is the kernel // of (u,v)->u+lambdaG2Affinev mod r. Expressing r, lambdaG2Affine as // polynomials in x, a short vector of this Zmodule is @@ -468,8 +475,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -508,7 +515,6 @@ func (p *G2Affine) ClearCofactor(a *G2Affine) *G2Affine { // ClearCofactor maps a point in curve to r-torsion func (p *G2Jac) ClearCofactor(a *G2Jac) *G2Jac { - var L0, L1, uP, u2P, u3P, tmp G2Jac uP.ScalarMultiplication(a, &xGen) @@ -596,15 +602,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -616,11 +622,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -645,6 +647,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fp.Element @@ -838,93 +842,70 @@ func (p *g2JacExtended) doubleMixed(q *G2Affine) *g2JacExtended { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -937,3 +918,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-756/g2_test.go b/ecc/bw6-756/g2_test.go index ecfc97332..fb099c989 100644 --- a/ecc/bw6-756/g2_test.go +++ b/ecc/bw6-756/g2_test.go @@ -19,6 +19,7 @@ package bw6756 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" @@ -325,7 +326,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -445,8 +446,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -459,7 +459,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -486,6 +486,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -498,8 +525,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-756/hash_to_g1.go b/ecc/bw6-756/hash_to_g1.go index 03ee750cf..bf4a9ba3a 100644 --- a/ecc/bw6-756/hash_to_g1.go +++ b/ecc/bw6-756/hash_to_g1.go @@ -17,7 +17,6 @@ package bw6756 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" "math/big" @@ -256,35 +255,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -302,11 +280,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -322,9 +300,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bw6-756/hash_to_g1_test.go b/ecc/bw6-756/hash_to_g1_test.go index c8e1d909a..7a7b2dbcf 100644 --- a/ecc/bw6-756/hash_to_g1_test.go +++ b/ecc/bw6-756/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-756/hash_to_g2.go b/ecc/bw6-756/hash_to_g2.go index d84ae9b50..f3246ede4 100644 --- a/ecc/bw6-756/hash_to_g2.go +++ b/ecc/bw6-756/hash_to_g2.go @@ -333,8 +333,8 @@ func g2EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -352,11 +352,11 @@ func MapToG2(u fp.Element) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SSWU map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -372,9 +372,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SSWU map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G2Affine{}, err } diff --git a/ecc/bw6-756/hash_to_g2_test.go b/ecc/bw6-756/hash_to_g2_test.go index bf37421b0..4ec5aa85b 100644 --- a/ecc/bw6-756/hash_to_g2_test.go +++ b/ecc/bw6-756/hash_to_g2_test.go @@ -62,7 +62,7 @@ func TestG2SqrtRatio(t *testing.T) { func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE2Prime(p G2Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE2Prime(p G2Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-756/internal/fptower/e3.go b/ecc/bw6-756/internal/fptower/e3.go index afd75d062..148fecded 100644 --- a/ecc/bw6-756/internal/fptower/e3.go +++ b/ecc/bw6-756/internal/fptower/e3.go @@ -83,6 +83,10 @@ func (z *E3) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() && z.A2.IsZero() } +func (z *E3) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() && z.A2.IsZero() +} + // Neg negates the E3 number func (z *E3) Neg(x *E3) *E3 { z.A0.Neg(&x.A0) @@ -91,22 +95,6 @@ func (z *E3) Neg(x *E3) *E3 { return z } -// ToMont converts to Mont form -func (z *E3) ToMont() *E3 { - z.A0.ToMont() - z.A1.ToMont() - z.A2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E3) FromMont() *E3 { - z.A0.FromMont() - z.A1.FromMont() - z.A2.FromMont() - return z -} - // Add adds two elements of E3 func (z *E3) Add(x, y *E3) *E3 { z.A0.Add(&x.A0, &y.A0) diff --git a/ecc/bw6-756/internal/fptower/e6.go b/ecc/bw6-756/internal/fptower/e6.go index 6dadd7270..e43bb2635 100644 --- a/ecc/bw6-756/internal/fptower/e6.go +++ b/ecc/bw6-756/internal/fptower/e6.go @@ -67,20 +67,6 @@ func (z *E6) SetOne() *E6 { return z } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - return z -} - // Add set z=x+y in E6 and return z func (z *E6) Add(x, y *E6) *E6 { z.B0.Add(&x.B0, &y.B0) @@ -118,6 +104,10 @@ func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() } +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() +} + // Mul set z=x*y in E6 and return z func (z *E6) Mul(x, y *E6) *E6 { var a, b, c E3 @@ -225,9 +215,12 @@ func (z *E6) CyclotomicSquareCompressed(x *E6) *E6 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -252,7 +245,7 @@ func (z *E6) DecompressKarabina(x *E6) *E6 { t[1].Sub(&t[0], &x.B0.A2). Double(&t[1]). Add(&t[1], &t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t[2].Square(&x.B1.A2) t[0].MulByNonResidue(&t[2]). Add(&t[0], &t[1]) @@ -525,8 +518,8 @@ func (z *E6) ExpGLV(x E6, k *big.Int) *E6 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { diff --git a/ecc/bw6-756/marshal.go b/ecc/bw6-756/marshal.go index 200f8829d..12283d885 100644 --- a/ecc/bw6-756/marshal.go +++ b/ecc/bw6-756/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,20 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -686,41 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - binary.BigEndian.PutUint64(res[136:144], tmp[6]) - binary.BigEndian.PutUint64(res[128:136], tmp[7]) - binary.BigEndian.PutUint64(res[120:128], tmp[8]) - binary.BigEndian.PutUint64(res[112:120], tmp[9]) - binary.BigEndian.PutUint64(res[104:112], tmp[10]) - binary.BigEndian.PutUint64(res[96:104], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -771,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -792,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -866,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -875,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -884,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -927,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -938,20 +910,7 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -970,41 +929,12 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - binary.BigEndian.PutUint64(res[136:144], tmp[6]) - binary.BigEndian.PutUint64(res[128:136], tmp[7]) - binary.BigEndian.PutUint64(res[120:128], tmp[8]) - binary.BigEndian.PutUint64(res[112:120], tmp[9]) - binary.BigEndian.PutUint64(res[104:112], tmp[10]) - binary.BigEndian.PutUint64(res[96:104], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -1055,8 +985,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1076,7 +1010,9 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -1150,7 +1086,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1159,7 +1095,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1168,10 +1104,12 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bw6-756/multiexp.go b/ecc/bw6-756/multiexp.go index 698835f34..fa9bd11b0 100644 --- a/ecc/bw6-756/multiexp.go +++ b/ecc/bw6-756/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 8, 11, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,85 +92,126 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC11] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC11, bucketG1AffineC11, bitSetC11, pG1AffineC11, ppG1AffineC11, qG1AffineC11, cG1AffineC11] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -327,257 +231,6 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 384, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -637,7 +290,7 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 8, 11, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -645,85 +298,126 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG2(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) } - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { switch c { + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC5] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 11: + const batchSize = 150 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC11] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC11, bucketG2AffineC11, bitSetC11, pG2AffineC11, ppG2AffineC11, qG2AffineC11, cG2AffineC11] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } } @@ -743,253 +437,188 @@ func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c +} - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int +} - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - // c doesn't divide 384, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + digits := make([]uint16, len(scalars)*int(nbChunks)) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var carry int -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // init with carry if any + digit := carry + carry = 0 - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) + + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bw6-756/multiexp_affine.go b/ecc/bw6-756/multiexp_affine.go new file mode 100644 index 000000000..086c2e9f8 --- /dev/null +++ b/ecc/bw6-756/multiexp_affine.go @@ -0,0 +1,549 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6756 + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-756/fp" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC11 [1024]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC11 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC11 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC11 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC11 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC11 | + qG1AffineC16 +} + +// batch size 150 when c = 11 +type cG1AffineC11 [150]fp.Element +type pG1AffineC11 [150]G1Affine +type ppG1AffineC11 [150]*G1Affine +type qG1AffineC11 [150]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC11 [1024]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC11 | + bucketG2AffineC16 +} + +// array of coordinates fp.Element +type cG2Affine interface { + cG2AffineC11 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC11 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC11 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC11 | + qG2AffineC16 +} + +// batch size 150 when c = 11 +type cG2AffineC11 [150]fp.Element +type pG2AffineC11 [150]G2Affine +type ppG2AffineC11 [150]*G2Affine +type qG2AffineC11 [150]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fp.Element +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC8 [128]bool +type bitSetC11 [1024]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC8 | + bitSetC11 | + bitSetC16 +} diff --git a/ecc/bw6-756/multiexp_jacobian.go b/ecc/bw6-756/multiexp_jacobian.go new file mode 100644 index 000000000..cd1504413 --- /dev/null +++ b/ecc/bw6-756/multiexp_jacobian.go @@ -0,0 +1,139 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6756 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC11 [1024]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC11 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC11 [1024]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC11 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bw6-756/multiexp_test.go b/ecc/bw6-756/multiexp_test.go index 4a6f1b161..8c8dec367 100644 --- a/ecc/bw6-756/multiexp_test.go +++ b/ecc/bw6-756/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-756/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 8, 16} + cRange := []uint64{3, 4, 5, 8, 11, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{3, 4, 5, 8, 11, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bw6-756/twistededwards/eddsa/doc.go b/ecc/bw6-756/twistededwards/eddsa/doc.go index 65fdfe7af..2e2d4c888 100644 --- a/ecc/bw6-756/twistededwards/eddsa/doc.go +++ b/ecc/bw6-756/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bw6-756's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bw6-756/twistededwards/eddsa/eddsa_test.go b/ecc/bw6-756/twistededwards/eddsa/eddsa_test.go index 8cc3ed1f4..82388cfdd 100644 --- a/ecc/bw6-756/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bw6-756/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bw6-756/twistededwards/eddsa/marshal.go b/ecc/bw6-756/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bw6-756/twistededwards/eddsa/marshal.go +++ b/ecc/bw6-756/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bw6-756/twistededwards/point.go b/ecc/bw6-756/twistededwards/point.go index fab8897d9..a0454e8b9 100644 --- a/ecc/bw6-756/twistededwards/point.go +++ b/ecc/bw6-756/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/bw6-761/bw6-761.go b/ecc/bw6-761/bw6-761.go index a115d92dc..6de70d925 100644 --- a/ecc/bw6-761/bw6-761.go +++ b/ecc/bw6-761/bw6-761.go @@ -1,24 +1,30 @@ // Package bw6761 efficient elliptic curve, pairing and hash to curve implementation for bw6-761. // // bw6-761: A Brezing--Weng curve (2-chain with bls12-377) -// embedding degree k=6 -// seed x₀=9586122913090633729 -// 𝔽p: p=6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 -// 𝔽r: r=258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// (E/𝔽p): Y²=X³-1 -// (Eₜ/𝔽p): Y² = X³+4 (M-type twist) -// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// +// embedding degree k=6 +// seed x₀=9586122913090633729 +// 𝔽p: p=6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 +// 𝔽r: r=258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// (E/𝔽p): Y²=X³-1 +// (Eₜ/𝔽p): Y² = X³+4 (M-type twist) +// r ∣ #E(Fp) and r ∣ #Eₜ(𝔽p) +// // Extension fields tower: -// 𝔽p³[u] = 𝔽p/u³+4 -// 𝔽p⁶[v] = 𝔽p²/v²-u +// +// 𝔽p³[u] = 𝔽p/u³+4 +// 𝔽p⁶[v] = 𝔽p²/v²-u +// // optimal Ate loops: -// x₀+1, x₀²-x₀-1 +// +// x₀+1, x₀²-x₀-1 +// // Security: estimated 126-bit level following [https://eprint.iacr.org/2019/885.pdf] // (r is 377 bits and p⁶ is 4566 bits) // // https://eprint.iacr.org/2020/351.pdf // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package bw6761 diff --git a/ecc/bw6-761/fp/doc.go b/ecc/bw6-761/fp/doc.go index a7340e5dc..1e84636b7 100644 --- a/ecc/bw6-761/fp/doc.go +++ b/ecc/bw6-761/fp/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [12]uint64 // -// Usage +// type Element [12]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 -// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b +// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 +// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fp diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index a410b94f8..ad3601b0b 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 12 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 -// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b +// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 +// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [12]uint64 const ( - Limbs = 12 // number of 64 bits words needed to represent a Element - Bits = 761 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 12 // number of 64 bits words needed to represent a Element + Bits = 761 // number of bits needed to represent a Element + Bytes = 96 // number of bytes needed to represent a Element ) // Field modulus q @@ -84,8 +84,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 -// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b +// q[base10] = 6891450384315732539396789682275657542479668912536150109513790160209623422243491736087683183289411687640864567753786613451161759120554247759349511699125301598951605099378508850372543631423596795951899700429969112842764913119068299 +// q[base16] = 0x122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -94,12 +94,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 744663313386281181 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("122e824fb83ce0ad187c94004faff3eb926186a81d14688528275ef8087be41707ba638e584e91903cebaff25b423048689c8ed12f9fd9071dcd3dc73ebff2e98a116c25667a8f8160cf8aeeaf0a437e6913e6870000082f49d00000000008b", 16) } @@ -107,8 +101,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -119,7 +114,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -157,14 +152,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fp.Element with ") @@ -290,15 +286,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -310,15 +304,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[11] > _x[11] { return 1 } else if _z[11] < _x[11] { @@ -389,8 +380,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 8813122258298994758, 0) @@ -513,67 +503,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -595,7 +527,7 @@ func (z *Element) Add(x, y *Element) *Element { z[10], carry = bits.Add64(x[10], y[10], carry) z[11], _ = bits.Add64(x[11], y[11], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -631,7 +563,7 @@ func (z *Element) Double(x *Element) *Element { z[10], carry = bits.Add64(x[10], x[10], carry) z[11], _ = bits.Add64(x[11], x[11], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -724,361 +656,531 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [12]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd1(v, y[6], c[1]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd1(v, y[7], c[1]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd1(v, y[8], c[1]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd1(v, y[9], c[1]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd1(v, y[10], c[1]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd1(v, y[11], c[1]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 6 - v := x[6] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 7 - v := x[7] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 8 - v := x[8] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 9 - v := x[9] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 10 - v := x[10] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], t[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], t[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], t[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], t[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], t[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], t[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - t[11], t[10] = madd3(m, q11, c[0], c[2], c[1]) - } - { - // round 11 - v := x[11] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - c[2], z[4] = madd2(m, q5, c[2], c[0]) - c[1], c[0] = madd2(v, y[6], c[1], t[6]) - c[2], z[5] = madd2(m, q6, c[2], c[0]) - c[1], c[0] = madd2(v, y[7], c[1], t[7]) - c[2], z[6] = madd2(m, q7, c[2], c[0]) - c[1], c[0] = madd2(v, y[8], c[1], t[8]) - c[2], z[7] = madd2(m, q8, c[2], c[0]) - c[1], c[0] = madd2(v, y[9], c[1], t[9]) - c[2], z[8] = madd2(m, q9, c[2], c[0]) - c[1], c[0] = madd2(v, y[10], c[1], t[10]) - c[2], z[9] = madd2(m, q10, c[2], c[0]) - c[1], c[0] = madd2(v, y[11], c[1], t[11]) - z[11], z[10] = madd3(m, q11, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [13]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + C, t[6] = madd1(y[0], x[6], C) + C, t[7] = madd1(y[0], x[7], C) + C, t[8] = madd1(y[0], x[8], C) + C, t[9] = madd1(y[0], x[9], C) + C, t[10] = madd1(y[0], x[10], C) + C, t[11] = madd1(y[0], x[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + C, t[6] = madd2(y[1], x[6], t[6], C) + C, t[7] = madd2(y[1], x[7], t[7], C) + C, t[8] = madd2(y[1], x[8], t[8], C) + C, t[9] = madd2(y[1], x[9], t[9], C) + C, t[10] = madd2(y[1], x[10], t[10], C) + C, t[11] = madd2(y[1], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + C, t[6] = madd2(y[2], x[6], t[6], C) + C, t[7] = madd2(y[2], x[7], t[7], C) + C, t[8] = madd2(y[2], x[8], t[8], C) + C, t[9] = madd2(y[2], x[9], t[9], C) + C, t[10] = madd2(y[2], x[10], t[10], C) + C, t[11] = madd2(y[2], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + C, t[6] = madd2(y[3], x[6], t[6], C) + C, t[7] = madd2(y[3], x[7], t[7], C) + C, t[8] = madd2(y[3], x[8], t[8], C) + C, t[9] = madd2(y[3], x[9], t[9], C) + C, t[10] = madd2(y[3], x[10], t[10], C) + C, t[11] = madd2(y[3], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + C, t[6] = madd2(y[4], x[6], t[6], C) + C, t[7] = madd2(y[4], x[7], t[7], C) + C, t[8] = madd2(y[4], x[8], t[8], C) + C, t[9] = madd2(y[4], x[9], t[9], C) + C, t[10] = madd2(y[4], x[10], t[10], C) + C, t[11] = madd2(y[4], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + C, t[6] = madd2(y[5], x[6], t[6], C) + C, t[7] = madd2(y[5], x[7], t[7], C) + C, t[8] = madd2(y[5], x[8], t[8], C) + C, t[9] = madd2(y[5], x[9], t[9], C) + C, t[10] = madd2(y[5], x[10], t[10], C) + C, t[11] = madd2(y[5], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[6], x[0], t[0]) + C, t[1] = madd2(y[6], x[1], t[1], C) + C, t[2] = madd2(y[6], x[2], t[2], C) + C, t[3] = madd2(y[6], x[3], t[3], C) + C, t[4] = madd2(y[6], x[4], t[4], C) + C, t[5] = madd2(y[6], x[5], t[5], C) + C, t[6] = madd2(y[6], x[6], t[6], C) + C, t[7] = madd2(y[6], x[7], t[7], C) + C, t[8] = madd2(y[6], x[8], t[8], C) + C, t[9] = madd2(y[6], x[9], t[9], C) + C, t[10] = madd2(y[6], x[10], t[10], C) + C, t[11] = madd2(y[6], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[7], x[0], t[0]) + C, t[1] = madd2(y[7], x[1], t[1], C) + C, t[2] = madd2(y[7], x[2], t[2], C) + C, t[3] = madd2(y[7], x[3], t[3], C) + C, t[4] = madd2(y[7], x[4], t[4], C) + C, t[5] = madd2(y[7], x[5], t[5], C) + C, t[6] = madd2(y[7], x[6], t[6], C) + C, t[7] = madd2(y[7], x[7], t[7], C) + C, t[8] = madd2(y[7], x[8], t[8], C) + C, t[9] = madd2(y[7], x[9], t[9], C) + C, t[10] = madd2(y[7], x[10], t[10], C) + C, t[11] = madd2(y[7], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[8], x[0], t[0]) + C, t[1] = madd2(y[8], x[1], t[1], C) + C, t[2] = madd2(y[8], x[2], t[2], C) + C, t[3] = madd2(y[8], x[3], t[3], C) + C, t[4] = madd2(y[8], x[4], t[4], C) + C, t[5] = madd2(y[8], x[5], t[5], C) + C, t[6] = madd2(y[8], x[6], t[6], C) + C, t[7] = madd2(y[8], x[7], t[7], C) + C, t[8] = madd2(y[8], x[8], t[8], C) + C, t[9] = madd2(y[8], x[9], t[9], C) + C, t[10] = madd2(y[8], x[10], t[10], C) + C, t[11] = madd2(y[8], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[9], x[0], t[0]) + C, t[1] = madd2(y[9], x[1], t[1], C) + C, t[2] = madd2(y[9], x[2], t[2], C) + C, t[3] = madd2(y[9], x[3], t[3], C) + C, t[4] = madd2(y[9], x[4], t[4], C) + C, t[5] = madd2(y[9], x[5], t[5], C) + C, t[6] = madd2(y[9], x[6], t[6], C) + C, t[7] = madd2(y[9], x[7], t[7], C) + C, t[8] = madd2(y[9], x[8], t[8], C) + C, t[9] = madd2(y[9], x[9], t[9], C) + C, t[10] = madd2(y[9], x[10], t[10], C) + C, t[11] = madd2(y[9], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[10], x[0], t[0]) + C, t[1] = madd2(y[10], x[1], t[1], C) + C, t[2] = madd2(y[10], x[2], t[2], C) + C, t[3] = madd2(y[10], x[3], t[3], C) + C, t[4] = madd2(y[10], x[4], t[4], C) + C, t[5] = madd2(y[10], x[5], t[5], C) + C, t[6] = madd2(y[10], x[6], t[6], C) + C, t[7] = madd2(y[10], x[7], t[7], C) + C, t[8] = madd2(y[10], x[8], t[8], C) + C, t[9] = madd2(y[10], x[9], t[9], C) + C, t[10] = madd2(y[10], x[10], t[10], C) + C, t[11] = madd2(y[10], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[11], x[0], t[0]) + C, t[1] = madd2(y[11], x[1], t[1], C) + C, t[2] = madd2(y[11], x[2], t[2], C) + C, t[3] = madd2(y[11], x[3], t[3], C) + C, t[4] = madd2(y[11], x[4], t[4], C) + C, t[5] = madd2(y[11], x[5], t[5], C) + C, t[6] = madd2(y[11], x[6], t[6], C) + C, t[7] = madd2(y[11], x[7], t[7], C) + C, t[8] = madd2(y[11], x[8], t[8], C) + C, t[9] = madd2(y[11], x[9], t[9], C) + C, t[10] = madd2(y[11], x[10], t[10], C) + C, t[11] = madd2(y[11], x[11], t[11], C) + + t[12], D = bits.Add64(t[12], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + C, t[5] = madd2(m, q6, t[6], C) + C, t[6] = madd2(m, q7, t[7], C) + C, t[7] = madd2(m, q8, t[8], C) + C, t[8] = madd2(m, q9, t[9], C) + C, t[9] = madd2(m, q10, t[10], C) + C, t[10] = madd2(m, q11, t[11], C) + + t[11], C = bits.Add64(t[12], C, 0) + t[12], _ = bits.Add64(0, D, C) + + if t[12] != 0 { + // we need to reduce, we have a result on 13 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], b = bits.Sub64(t[5], q5, b) + z[6], b = bits.Sub64(t[6], q6, b) + z[7], b = bits.Sub64(t[7], q7, b) + z[8], b = bits.Sub64(t[8], q8, b) + z[9], b = bits.Sub64(t[9], q9, b) + z[10], b = bits.Sub64(t[10], q10, b) + z[11], _ = bits.Sub64(t[11], q11, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + z[6] = t[6] + z[7] = t[7] + z[8] = t[8] + z[9] = t[9] + z[10] = t[10] + z[11] = t[11] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1094,7 +1196,6 @@ func _mulGeneric(z, x, y *Element) { z[10], b = bits.Sub64(z[10], q10, b) z[11], _ = bits.Sub64(z[11], q11, b) } - } func _fromMontGeneric(z *Element) { @@ -1306,7 +1407,7 @@ func _fromMontGeneric(z *Element) { z[11] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1326,7 +1427,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -1422,6 +1523,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -1436,8 +1566,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -1471,23 +1601,37 @@ var rSquare = Element{ 48736111365249031, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[88:96], z[0]) + binary.BigEndian.PutUint64(b[80:88], z[1]) + binary.BigEndian.PutUint64(b[72:80], z[2]) + binary.BigEndian.PutUint64(b[64:72], z[3]) + binary.BigEndian.PutUint64(b[56:64], z[4]) + binary.BigEndian.PutUint64(b[48:56], z[5]) + binary.BigEndian.PutUint64(b[40:48], z[6]) + binary.BigEndian.PutUint64(b[32:40], z[7]) + binary.BigEndian.PutUint64(b[24:32], z[8]) + binary.BigEndian.PutUint64(b[16:24], z[9]) + binary.BigEndian.PutUint64(b[8:16], z[10]) + binary.BigEndian.PutUint64(b[0:8], z[11]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -1506,63 +1650,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[88:96], z[0]) - binary.BigEndian.PutUint64(b[80:88], z[1]) - binary.BigEndian.PutUint64(b[72:80], z[2]) - binary.BigEndian.PutUint64(b[64:72], z[3]) - binary.BigEndian.PutUint64(b[56:64], z[4]) - binary.BigEndian.PutUint64(b[48:56], z[5]) - binary.BigEndian.PutUint64(b[40:48], z[6]) - binary.BigEndian.PutUint64(b[32:40], z[7]) - binary.BigEndian.PutUint64(b[24:32], z[8]) - binary.BigEndian.PutUint64(b[16:24], z[9]) - binary.BigEndian.PutUint64(b[8:16], z[10]) - binary.BigEndian.PutUint64(b[0:8], z[11]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[88:96], _z[0]) - binary.BigEndian.PutUint64(res[80:88], _z[1]) - binary.BigEndian.PutUint64(res[72:80], _z[2]) - binary.BigEndian.PutUint64(res[64:72], _z[3]) - binary.BigEndian.PutUint64(res[56:64], _z[4]) - binary.BigEndian.PutUint64(res[48:56], _z[5]) - binary.BigEndian.PutUint64(res[40:48], _z[6]) - binary.BigEndian.PutUint64(res[32:40], _z[7]) - binary.BigEndian.PutUint64(res[24:32], _z[8]) - binary.BigEndian.PutUint64(res[16:24], _z[9]) - binary.BigEndian.PutUint64(res[8:16], _z[10]) - binary.BigEndian.PutUint64(res[0:8], _z[11]) +// Bits provides access to z by returning its value as a little-endian [12]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [12]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1575,19 +1705,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 96-byte integer. +// If e is not a 96-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1605,17 +1760,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1637,20 +1791,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1659,7 +1813,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1668,7 +1822,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1708,7 +1862,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1717,10 +1871,111 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 96-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[88:96]) + z[1] = binary.BigEndian.Uint64((*b)[80:88]) + z[2] = binary.BigEndian.Uint64((*b)[72:80]) + z[3] = binary.BigEndian.Uint64((*b)[64:72]) + z[4] = binary.BigEndian.Uint64((*b)[56:64]) + z[5] = binary.BigEndian.Uint64((*b)[48:56]) + z[6] = binary.BigEndian.Uint64((*b)[40:48]) + z[7] = binary.BigEndian.Uint64((*b)[32:40]) + z[8] = binary.BigEndian.Uint64((*b)[24:32]) + z[9] = binary.BigEndian.Uint64((*b)[16:24]) + z[10] = binary.BigEndian.Uint64((*b)[8:16]) + z[11] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[88:96], e[0]) + binary.BigEndian.PutUint64((*b)[80:88], e[1]) + binary.BigEndian.PutUint64((*b)[72:80], e[2]) + binary.BigEndian.PutUint64((*b)[64:72], e[3]) + binary.BigEndian.PutUint64((*b)[56:64], e[4]) + binary.BigEndian.PutUint64((*b)[48:56], e[5]) + binary.BigEndian.PutUint64((*b)[40:48], e[6]) + binary.BigEndian.PutUint64((*b)[32:40], e[7]) + binary.BigEndian.PutUint64((*b)[24:32], e[8]) + binary.BigEndian.PutUint64((*b)[16:24], e[9]) + binary.BigEndian.PutUint64((*b)[8:16], e[10]) + binary.BigEndian.PutUint64((*b)[0:8], e[11]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + z[6] = binary.LittleEndian.Uint64((*b)[48:56]) + z[7] = binary.LittleEndian.Uint64((*b)[56:64]) + z[8] = binary.LittleEndian.Uint64((*b)[64:72]) + z[9] = binary.LittleEndian.Uint64((*b)[72:80]) + z[10] = binary.LittleEndian.Uint64((*b)[80:88]) + z[11] = binary.LittleEndian.Uint64((*b)[88:96]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) + binary.LittleEndian.PutUint64((*b)[48:56], e[6]) + binary.LittleEndian.PutUint64((*b)[56:64], e[7]) + binary.LittleEndian.PutUint64((*b)[64:72], e[8]) + binary.LittleEndian.PutUint64((*b)[72:80], e[9]) + binary.LittleEndian.PutUint64((*b)[80:88], e[10]) + binary.LittleEndian.PutUint64((*b)[88:96], e[11]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1936,7 +2191,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1959,17 +2214,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -2252,7 +2518,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[11], z[10] = madd2(m, q11, t[i+11], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-761/fp/element_mul_adx_amd64.s b/ecc/bw6-761/fp/element_mul_adx_amd64.s deleted file mode 100644 index ad5011087..000000000 --- a/ecc/bw6-761/fp/element_mul_adx_amd64.s +++ /dev/null @@ -1,2738 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0xf49d00000000008b -DATA q<>+8(SB)/8, $0xe6913e6870000082 -DATA q<>+16(SB)/8, $0x160cf8aeeaf0a437 -DATA q<>+24(SB)/8, $0x98a116c25667a8f8 -DATA q<>+32(SB)/8, $0x71dcd3dc73ebff2e -DATA q<>+40(SB)/8, $0x8689c8ed12f9fd90 -DATA q<>+48(SB)/8, $0x03cebaff25b42304 -DATA q<>+56(SB)/8, $0x707ba638e584e919 -DATA q<>+64(SB)/8, $0x528275ef8087be41 -DATA q<>+72(SB)/8, $0xb926186a81d14688 -DATA q<>+80(SB)/8, $0xd187c94004faff3e -DATA q<>+88(SB)/8, $0x0122e824fb83ce0a -GLOBL q<>(SB), (RODATA+NOPTR), $96 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x0a5593568fa798dd -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, ra6, ra7, ra8, ra9, ra10, ra11, rb0, rb1, rb2, rb3, rb4, rb5, rb6, rb7, rb8, rb9, rb10, rb11) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - MOVQ ra6, rb6; \ - SBBQ q<>+48(SB), ra6; \ - MOVQ ra7, rb7; \ - SBBQ q<>+56(SB), ra7; \ - MOVQ ra8, rb8; \ - SBBQ q<>+64(SB), ra8; \ - MOVQ ra9, rb9; \ - SBBQ q<>+72(SB), ra9; \ - MOVQ ra10, rb10; \ - SBBQ q<>+80(SB), ra10; \ - MOVQ ra11, rb11; \ - SBBQ q<>+88(SB), ra11; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - CMOVQCS rb6, ra6; \ - CMOVQCS rb7, ra7; \ - CMOVQCS rb8, ra8; \ - CMOVQCS rb9, ra9; \ - CMOVQCS rb10, ra10; \ - CMOVQCS rb11, ra11; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), $96-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - NO_LOCAL_POINTERS - MOVQ x+8(FP), AX - - // x[0] -> s0-8(SP) - // x[1] -> s1-16(SP) - // x[2] -> s2-24(SP) - // x[3] -> s3-32(SP) - // x[4] -> s4-40(SP) - // x[5] -> s5-48(SP) - // x[6] -> s6-56(SP) - // x[7] -> s7-64(SP) - // x[8] -> s8-72(SP) - // x[9] -> s9-80(SP) - // x[10] -> s10-88(SP) - // x[11] -> s11-96(SP) - MOVQ 0(AX), R14 - MOVQ 8(AX), R15 - MOVQ 16(AX), CX - MOVQ 24(AX), BX - MOVQ 32(AX), SI - MOVQ 40(AX), DI - MOVQ 48(AX), R8 - MOVQ 56(AX), R9 - MOVQ 64(AX), R10 - MOVQ 72(AX), R11 - MOVQ 80(AX), R12 - MOVQ 88(AX), R13 - MOVQ R14, s0-8(SP) - MOVQ R15, s1-16(SP) - MOVQ CX, s2-24(SP) - MOVQ BX, s3-32(SP) - MOVQ SI, s4-40(SP) - MOVQ DI, s5-48(SP) - MOVQ R8, s6-56(SP) - MOVQ R9, s7-64(SP) - MOVQ R10, s8-72(SP) - MOVQ R11, s9-80(SP) - MOVQ R12, s10-88(SP) - MOVQ R13, s11-96(SP) - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // t[6] -> R8 - // t[7] -> R9 - // t[8] -> R10 - // t[9] -> R11 - // t[10] -> R12 - // t[11] -> R13 - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 0(AX), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ s0-8(SP), R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ s1-16(SP), AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ s2-24(SP), AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ s3-32(SP), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ s4-40(SP), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ s5-48(SP), AX, R8 - ADOXQ AX, DI - - // (A,t[6]) := x[6]*y[0] + A - MULXQ s6-56(SP), AX, R9 - ADOXQ AX, R8 - - // (A,t[7]) := x[7]*y[0] + A - MULXQ s7-64(SP), AX, R10 - ADOXQ AX, R9 - - // (A,t[8]) := x[8]*y[0] + A - MULXQ s8-72(SP), AX, R11 - ADOXQ AX, R10 - - // (A,t[9]) := x[9]*y[0] + A - MULXQ s9-80(SP), AX, R12 - ADOXQ AX, R11 - - // (A,t[10]) := x[10]*y[0] + A - MULXQ s10-88(SP), AX, R13 - ADOXQ AX, R12 - - // (A,t[11]) := x[11]*y[0] + A - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 8(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[1] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[1] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[1] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[1] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[1] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[1] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 16(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[2] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[2] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[2] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[2] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[2] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[2] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 24(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[3] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[3] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[3] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[3] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[3] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[3] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 32(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[4] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[4] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[4] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[4] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[4] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[4] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 40(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[5] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[5] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[5] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[5] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[5] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[5] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 48(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[6] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[6] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[6] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[6] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[6] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[6] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[6] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[6] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[6] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[6] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[6] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[6] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 56(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[7] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[7] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[7] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[7] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[7] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[7] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[7] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[7] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[7] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[7] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[7] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[7] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 64(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[8] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[8] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[8] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[8] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[8] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[8] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[8] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[8] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[8] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[8] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[8] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[8] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 72(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[9] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[9] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[9] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[9] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[9] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[9] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[9] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[9] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[9] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[9] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[9] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[9] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 80(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[10] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[10] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[10] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[10] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[10] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[10] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[10] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[10] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[10] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[10] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[10] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[10] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // clear the flags - XORQ AX, AX - MOVQ y+16(FP), AX - MOVQ 88(AX), DX - - // (A,t[0]) := t[0] + x[0]*y[11] + A - MULXQ s0-8(SP), AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[11] + A - ADCXQ BP, R15 - MULXQ s1-16(SP), AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[11] + A - ADCXQ BP, CX - MULXQ s2-24(SP), AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[11] + A - ADCXQ BP, BX - MULXQ s3-32(SP), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[11] + A - ADCXQ BP, SI - MULXQ s4-40(SP), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[11] + A - ADCXQ BP, DI - MULXQ s5-48(SP), AX, BP - ADOXQ AX, DI - - // (A,t[6]) := t[6] + x[6]*y[11] + A - ADCXQ BP, R8 - MULXQ s6-56(SP), AX, BP - ADOXQ AX, R8 - - // (A,t[7]) := t[7] + x[7]*y[11] + A - ADCXQ BP, R9 - MULXQ s7-64(SP), AX, BP - ADOXQ AX, R9 - - // (A,t[8]) := t[8] + x[8]*y[11] + A - ADCXQ BP, R10 - MULXQ s8-72(SP), AX, BP - ADOXQ AX, R10 - - // (A,t[9]) := t[9] + x[9]*y[11] + A - ADCXQ BP, R11 - MULXQ s9-80(SP), AX, BP - ADOXQ AX, R11 - - // (A,t[10]) := t[10] + x[10]*y[11] + A - ADCXQ BP, R12 - MULXQ s10-88(SP), AX, BP - ADOXQ AX, R12 - - // (A,t[11]) := t[11] + x[11]*y[11] + A - ADCXQ BP, R13 - MULXQ s11-96(SP), AX, BP - ADOXQ AX, R13 - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - PUSHQ BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - POPQ BP - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - - // t[11] = C + A - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ BP, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET - -TEXT ·fromMont(SB), $96-8 - NO_LOCAL_POINTERS - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - MOVQ 48(DX), R8 - MOVQ 56(DX), R9 - MOVQ 64(DX), R10 - MOVQ 72(DX), R11 - MOVQ 80(DX), R12 - MOVQ 88(DX), R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // (C,t[5]) := t[6] + m*q[6] + C - ADCXQ R8, DI - MULXQ q<>+48(SB), AX, R8 - ADOXQ AX, DI - - // (C,t[6]) := t[7] + m*q[7] + C - ADCXQ R9, R8 - MULXQ q<>+56(SB), AX, R9 - ADOXQ AX, R8 - - // (C,t[7]) := t[8] + m*q[8] + C - ADCXQ R10, R9 - MULXQ q<>+64(SB), AX, R10 - ADOXQ AX, R9 - - // (C,t[8]) := t[9] + m*q[9] + C - ADCXQ R11, R10 - MULXQ q<>+72(SB), AX, R11 - ADOXQ AX, R10 - - // (C,t[9]) := t[10] + m*q[10] + C - ADCXQ R12, R11 - MULXQ q<>+80(SB), AX, R12 - ADOXQ AX, R11 - - // (C,t[10]) := t[11] + m*q[11] + C - ADCXQ R13, R12 - MULXQ q<>+88(SB), AX, R13 - ADOXQ AX, R12 - MOVQ $0, AX - ADCXQ AX, R13 - ADOXQ AX, R13 - - // reduce element(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) using temp registers (s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13,s0-8(SP),s1-16(SP),s2-24(SP),s3-32(SP),s4-40(SP),s5-48(SP),s6-56(SP),s7-64(SP),s8-72(SP),s9-80(SP),s10-88(SP),s11-96(SP)) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - MOVQ R8, 48(AX) - MOVQ R9, 56(AX) - MOVQ R10, 64(AX) - MOVQ R11, 72(AX) - MOVQ R12, 80(AX) - MOVQ R13, 88(AX) - RET diff --git a/ecc/bw6-761/fp/element_mul_amd64.s b/ecc/bw6-761/fp/element_mul_amd64.s index 74517b40b..478922653 100644 --- a/ecc/bw6-761/fp/element_mul_amd64.s +++ b/ecc/bw6-761/fp/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fp/element_ops_amd64.go b/ecc/bw6-761/fp/element_ops_amd64.go index a3c830471..83bba45ae 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.go +++ b/ecc/bw6-761/fp/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-761/fp/element_ops_amd64.s b/ecc/bw6-761/fp/element_ops_amd64.s index 71af0aacb..c0f7ed239 100644 --- a/ecc/bw6-761/fp/element_ops_amd64.s +++ b/ecc/bw6-761/fp/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-761/fp/element_ops_noasm.go b/ecc/bw6-761/fp/element_ops_noasm.go deleted file mode 100644 index ea474db82..000000000 --- a/ecc/bw6-761/fp/element_ops_noasm.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fp - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 4345973640412121648, - 16340807117537158706, - 14673764841507373218, - 5587754667198343811, - 12846753860245084942, - 4041391838244625385, - 8324122986343791677, - 8773809490091176420, - 5465994123296109449, - 6649773564661156048, - 9147430723089113754, - 54281803719730243, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-761/fp/element_ops_purego.go b/ecc/bw6-761/fp/element_ops_purego.go new file mode 100644 index 000000000..3c1ffa245 --- /dev/null +++ b/ecc/bw6-761/fp/element_ops_purego.go @@ -0,0 +1,2227 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 4345973640412121648, + 16340807117537158706, + 14673764841507373218, + 5587754667198343811, + 12846753860245084942, + 4041391838244625385, + 8324122986343791677, + 8773809490091176420, + 5465994123296109449, + 6649773564661156048, + 9147430723089113754, + 54281803719730243, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + u6, t6 = bits.Mul64(v, y[6]) + u7, t7 = bits.Mul64(v, y[7]) + u8, t8 = bits.Mul64(v, y[8]) + u9, t9 = bits.Mul64(v, y[9]) + u10, t10 = bits.Mul64(v, y[10]) + u11, t11 = bits.Mul64(v, y[11]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[10] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[11] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, y[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, y[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, y[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, y[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, y[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, y[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + z[10] = t10 + z[11] = t11 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], b = bits.Sub64(z[9], q9, b) + z[10], b = bits.Sub64(z[10], q10, b) + z[11], _ = bits.Sub64(z[11], q11, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11 uint64 + var u0, u1, u2, u3, u4, u5, u6, u7, u8, u9, u10, u11 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + u6, t6 = bits.Mul64(v, x[6]) + u7, t7 = bits.Mul64(v, x[7]) + u8, t8 = bits.Mul64(v, x[8]) + u9, t9 = bits.Mul64(v, x[9]) + u10, t10 = bits.Mul64(v, x[10]) + u11, t11 = bits.Mul64(v, x[11]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[6] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[7] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[8] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[9] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[10] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[11] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + u6, c1 = bits.Mul64(v, x[6]) + t6, c0 = bits.Add64(c1, t6, c0) + u7, c1 = bits.Mul64(v, x[7]) + t7, c0 = bits.Add64(c1, t7, c0) + u8, c1 = bits.Mul64(v, x[8]) + t8, c0 = bits.Add64(c1, t8, c0) + u9, c1 = bits.Mul64(v, x[9]) + t9, c0 = bits.Add64(c1, t9, c0) + u10, c1 = bits.Mul64(v, x[10]) + t10, c0 = bits.Add64(c1, t10, c0) + u11, c1 = bits.Mul64(v, x[11]) + t11, c0 = bits.Add64(c1, t11, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + t6, c0 = bits.Add64(u5, t6, c0) + t7, c0 = bits.Add64(u6, t7, c0) + t8, c0 = bits.Add64(u7, t8, c0) + t9, c0 = bits.Add64(u8, t9, c0) + t10, c0 = bits.Add64(u9, t10, c0) + t11, c0 = bits.Add64(u10, t11, c0) + c2, _ = bits.Add64(u11, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + t4, c0 = bits.Add64(t5, c1, c0) + u6, c1 = bits.Mul64(m, q6) + t5, c0 = bits.Add64(t6, c1, c0) + u7, c1 = bits.Mul64(m, q7) + t6, c0 = bits.Add64(t7, c1, c0) + u8, c1 = bits.Mul64(m, q8) + t7, c0 = bits.Add64(t8, c1, c0) + u9, c1 = bits.Mul64(m, q9) + t8, c0 = bits.Add64(t9, c1, c0) + u10, c1 = bits.Mul64(m, q10) + t9, c0 = bits.Add64(t10, c1, c0) + u11, c1 = bits.Mul64(m, q11) + + t10, c0 = bits.Add64(0, c1, c0) + u11, _ = bits.Add64(u11, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + t5, c0 = bits.Add64(u5, t5, c0) + t6, c0 = bits.Add64(u6, t6, c0) + t7, c0 = bits.Add64(u7, t7, c0) + t8, c0 = bits.Add64(u8, t8, c0) + t9, c0 = bits.Add64(u9, t9, c0) + t10, c0 = bits.Add64(u10, t10, c0) + c2, _ = bits.Add64(c2, 0, c0) + t10, c0 = bits.Add64(t11, t10, 0) + t11, _ = bits.Add64(u11, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + z[6] = t6 + z[7] = t7 + z[8] = t8 + z[9] = t9 + z[10] = t10 + z[11] = t11 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], b = bits.Sub64(z[5], q5, b) + z[6], b = bits.Sub64(z[6], q6, b) + z[7], b = bits.Sub64(z[7], q7, b) + z[8], b = bits.Sub64(z[8], q8, b) + z[9], b = bits.Sub64(z[9], q9, b) + z[10], b = bits.Sub64(z[10], q10, b) + z[11], _ = bits.Sub64(z[11], q11, b) + } + return z +} diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index f91e419a0..fc8e03d74 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -654,7 +647,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -769,7 +762,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -782,13 +775,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -817,17 +810,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -878,7 +871,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -891,13 +884,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -926,17 +919,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -987,7 +980,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1000,7 +993,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1014,7 +1007,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1054,11 +1047,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1071,7 +1064,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1123,7 +1116,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1136,14 +1129,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1172,18 +1165,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1234,7 +1227,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1247,13 +1240,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1282,17 +1275,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1337,7 +1330,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1358,14 +1351,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1409,7 +1402,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1430,14 +1423,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1481,7 +1474,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1502,14 +1495,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1553,7 +1546,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1574,14 +1567,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1625,7 +1618,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1646,14 +1639,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2041,7 +2034,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2174,17 +2167,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2279,7 +2272,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2365,7 +2358,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2415,7 +2408,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord10, inversionCorrectionFactorWord11, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2457,7 +2450,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2618,7 +2611,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2831,11 +2824,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2872,7 +2865,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2890,7 +2883,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2901,7 +2894,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-761/fr/doc.go b/ecc/bw6-761/fr/doc.go index a3bf83e31..e3d46536b 100644 --- a/ecc/bw6-761/fr/doc.go +++ b/ecc/bw6-761/fr/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [6]uint64 // -// Usage +// type Element [6]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package fr diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index dd4c184df..bf8aa530e 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 6 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [6]uint64 const ( - Limbs = 6 // number of 64 bits words needed to represent a Element - Bits = 377 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 6 // number of 64 bits words needed to represent a Element + Bits = 377 // number of bits needed to represent a Element + Bytes = 48 // number of bytes needed to represent a Element ) // Field modulus q @@ -72,8 +72,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 -// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 +// q[base10] = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +// q[base16] = 0x1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -82,12 +82,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 9586122913090633727 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("1ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001", 16) } @@ -95,8 +89,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -107,7 +102,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -139,14 +134,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set fr.Element with ") @@ -260,15 +256,13 @@ func (z *Element) IsOne() bool { // IsUint64 reports whether z can be represented as an uint64. func (z *Element) IsUint64() bool { zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -280,15 +274,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[5] > _x[5] { return 1 } else if _z[5] < _x[5] { @@ -329,8 +320,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 4793061456545316865, 0) @@ -429,67 +419,9 @@ func (z *Element) Halve() { } -// Mul z = x * y (mod q) -// -// x and y must be strictly inferior to q -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - mul(z, x, y) - return z -} - -// Square z = x * x (mod q) -// -// x must be strictly inferior to q -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - mul(z, x, x) - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -505,7 +437,7 @@ func (z *Element) Add(x, y *Element) *Element { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -529,7 +461,7 @@ func (z *Element) Double(x *Element) *Element { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -592,115 +524,219 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var t [6]uint64 - var c [3]uint64 - { - // round 0 - v := x[0] - c[1], c[0] = bits.Mul64(v, y[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd1(v, y[1], c[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd1(v, y[2], c[1]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd1(v, y[3], c[1]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd1(v, y[4], c[1]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd1(v, y[5], c[1]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 1 - v := x[1] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 2 - v := x[2] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 3 - v := x[3] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 4 - v := x[4] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], t[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], t[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], t[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], t[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - t[5], t[4] = madd3(m, q5, c[0], c[2], c[1]) - } - { - // round 5 - v := x[5] - c[1], c[0] = madd1(v, y[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - c[1], c[0] = madd2(v, y[1], c[1], t[1]) - c[2], z[0] = madd2(m, q1, c[2], c[0]) - c[1], c[0] = madd2(v, y[2], c[1], t[2]) - c[2], z[1] = madd2(m, q2, c[2], c[0]) - c[1], c[0] = madd2(v, y[3], c[1], t[3]) - c[2], z[2] = madd2(m, q3, c[2], c[0]) - c[1], c[0] = madd2(v, y[4], c[1], t[4]) - c[2], z[3] = madd2(m, q4, c[2], c[0]) - c[1], c[0] = madd2(v, y[5], c[1], t[5]) - z[5], z[4] = madd3(m, q5, c[0], c[2], c[1]) - } - - // if z >= q → z -= q + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [7]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + C, t[4] = madd1(y[0], x[4], C) + C, t[5] = madd1(y[0], x[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + C, t[4] = madd2(y[1], x[4], t[4], C) + C, t[5] = madd2(y[1], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + C, t[4] = madd2(y[2], x[4], t[4], C) + C, t[5] = madd2(y[2], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + C, t[4] = madd2(y[3], x[4], t[4], C) + C, t[5] = madd2(y[3], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[4], x[0], t[0]) + C, t[1] = madd2(y[4], x[1], t[1], C) + C, t[2] = madd2(y[4], x[2], t[2], C) + C, t[3] = madd2(y[4], x[3], t[3], C) + C, t[4] = madd2(y[4], x[4], t[4], C) + C, t[5] = madd2(y[4], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[5], x[0], t[0]) + C, t[1] = madd2(y[5], x[1], t[1], C) + C, t[2] = madd2(y[5], x[2], t[2], C) + C, t[3] = madd2(y[5], x[3], t[3], C) + C, t[4] = madd2(y[5], x[4], t[4], C) + C, t[5] = madd2(y[5], x[5], t[5], C) + + t[6], D = bits.Add64(t[6], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + C, t[3] = madd2(m, q4, t[4], C) + C, t[4] = madd2(m, q5, t[5], C) + + t[5], C = bits.Add64(t[6], C, 0) + t[6], _ = bits.Add64(0, D, C) + + if t[6] != 0 { + // we need to reduce, we have a result on 7 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], b = bits.Sub64(t[3], q3, b) + z[4], b = bits.Sub64(t[4], q4, b) + z[5], _ = bits.Sub64(t[5], q5, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + z[4] = t[4] + z[5] = t[5] + + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -710,7 +746,6 @@ func _mulGeneric(z, x, y *Element) { z[4], b = bits.Sub64(z[4], q4, b) z[5], _ = bits.Sub64(z[5], q5, b) } - } func _fromMontGeneric(z *Element) { @@ -784,7 +819,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -798,7 +833,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) @@ -870,6 +905,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -884,8 +948,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -913,23 +977,31 @@ var rSquare = Element{ 30958721782860680, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[40:48], z[0]) + binary.BigEndian.PutUint64(b[32:40], z[1]) + binary.BigEndian.PutUint64(b[24:32], z[2]) + binary.BigEndian.PutUint64(b[16:24], z[3]) + binary.BigEndian.PutUint64(b[8:16], z[4]) + binary.BigEndian.PutUint64(b[0:8], z[5]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -948,51 +1020,49 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[40:48], z[0]) - binary.BigEndian.PutUint64(b[32:40], z[1]) - binary.BigEndian.PutUint64(b[24:32], z[2]) - binary.BigEndian.PutUint64(b[16:24], z[3]) - binary.BigEndian.PutUint64(b[8:16], z[4]) - binary.BigEndian.PutUint64(b[0:8], z[5]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[40:48], _z[0]) - binary.BigEndian.PutUint64(res[32:40], _z[1]) - binary.BigEndian.PutUint64(res[24:32], _z[2]) - binary.BigEndian.PutUint64(res[16:24], _z[3]) - binary.BigEndian.PutUint64(res[8:16], _z[4]) - binary.BigEndian.PutUint64(res[0:8], _z[5]) +// Bits provides access to z by returning its value as a little-endian [6]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [6]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -1005,19 +1075,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 48-byte integer. +// If e is not a 48-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -1035,17 +1130,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -1067,20 +1161,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -1089,7 +1183,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -1098,7 +1192,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -1138,7 +1232,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -1147,10 +1241,87 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 48-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[40:48]) + z[1] = binary.BigEndian.Uint64((*b)[32:40]) + z[2] = binary.BigEndian.Uint64((*b)[24:32]) + z[3] = binary.BigEndian.Uint64((*b)[16:24]) + z[4] = binary.BigEndian.Uint64((*b)[8:16]) + z[5] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[40:48], e[0]) + binary.BigEndian.PutUint64((*b)[32:40], e[1]) + binary.BigEndian.PutUint64((*b)[24:32], e[2]) + binary.BigEndian.PutUint64((*b)[16:24], e[3]) + binary.BigEndian.PutUint64((*b)[8:16], e[4]) + binary.BigEndian.PutUint64((*b)[0:8], e[5]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + z[4] = binary.LittleEndian.Uint64((*b)[32:40]) + z[5] = binary.LittleEndian.Uint64((*b)[40:48]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) + binary.LittleEndian.PutUint64((*b)[32:40], e[4]) + binary.LittleEndian.PutUint64((*b)[40:48], e[5]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -1183,7 +1354,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -1198,7 +1369,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(46) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -1396,7 +1567,7 @@ func (z *Element) Inverse(x *Element) *Element { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -1413,17 +1584,28 @@ func (z *Element) Inverse(x *Element) *Element { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *Element) inverseExp(x *Element) *Element { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *Element) inverseExp(x Element) *Element { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits @@ -1556,7 +1738,7 @@ func (z *Element) montReduceSigned(x *Element, xHi uint64) { z[5], z[4] = madd2(m, q5, t[i+5], C) } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { var b uint64 z[0], b = bits.Sub64(z[0], q0, 0) diff --git a/ecc/bw6-761/fr/element_mul_adx_amd64.s b/ecc/bw6-761/fr/element_mul_adx_amd64.s deleted file mode 100644 index e2afd074d..000000000 --- a/ecc/bw6-761/fr/element_mul_adx_amd64.s +++ /dev/null @@ -1,835 +0,0 @@ -// +build amd64_adx - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "textflag.h" -#include "funcdata.h" - -// modulus q -DATA q<>+0(SB)/8, $0x8508c00000000001 -DATA q<>+8(SB)/8, $0x170b5d4430000000 -DATA q<>+16(SB)/8, $0x1ef3622fba094800 -DATA q<>+24(SB)/8, $0x1a22d9f300f5138f -DATA q<>+32(SB)/8, $0xc63b05c06ca1493b -DATA q<>+40(SB)/8, $0x01ae3a4617c510ea -GLOBL q<>(SB), (RODATA+NOPTR), $48 - -// qInv0 q'[0] -DATA qInv0<>(SB)/8, $0x8508bfffffffffff -GLOBL qInv0<>(SB), (RODATA+NOPTR), $8 - -#define REDUCE(ra0, ra1, ra2, ra3, ra4, ra5, rb0, rb1, rb2, rb3, rb4, rb5) \ - MOVQ ra0, rb0; \ - SUBQ q<>(SB), ra0; \ - MOVQ ra1, rb1; \ - SBBQ q<>+8(SB), ra1; \ - MOVQ ra2, rb2; \ - SBBQ q<>+16(SB), ra2; \ - MOVQ ra3, rb3; \ - SBBQ q<>+24(SB), ra3; \ - MOVQ ra4, rb4; \ - SBBQ q<>+32(SB), ra4; \ - MOVQ ra5, rb5; \ - SBBQ q<>+40(SB), ra5; \ - CMOVQCS rb0, ra0; \ - CMOVQCS rb1, ra1; \ - CMOVQCS rb2, ra2; \ - CMOVQCS rb3, ra3; \ - CMOVQCS rb4, ra4; \ - CMOVQCS rb5, ra5; \ - -// mul(res, x, y *Element) -TEXT ·mul(SB), NOSPLIT, $0-24 - - // the algorithm is described in the Element.Mul declaration (.go) - // however, to benefit from the ADCX and ADOX carry chains - // we split the inner loops in 2: - // for i=0 to N-1 - // for j=0 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C + A - - MOVQ x+8(FP), R8 - - // x[0] -> R10 - // x[1] -> R11 - // x[2] -> R12 - MOVQ 0(R8), R10 - MOVQ 8(R8), R11 - MOVQ 16(R8), R12 - MOVQ y+16(FP), R13 - - // A -> BP - // t[0] -> R14 - // t[1] -> R15 - // t[2] -> CX - // t[3] -> BX - // t[4] -> SI - // t[5] -> DI - // clear the flags - XORQ AX, AX - MOVQ 0(R13), DX - - // (A,t[0]) := x[0]*y[0] + A - MULXQ R10, R14, R15 - - // (A,t[1]) := x[1]*y[0] + A - MULXQ R11, AX, CX - ADOXQ AX, R15 - - // (A,t[2]) := x[2]*y[0] + A - MULXQ R12, AX, BX - ADOXQ AX, CX - - // (A,t[3]) := x[3]*y[0] + A - MULXQ 24(R8), AX, SI - ADOXQ AX, BX - - // (A,t[4]) := x[4]*y[0] + A - MULXQ 32(R8), AX, DI - ADOXQ AX, SI - - // (A,t[5]) := x[5]*y[0] + A - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 8(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[1] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[1] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[1] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[1] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[1] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[1] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 16(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[2] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[2] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[2] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[2] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[2] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[2] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 24(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[3] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[3] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[3] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[3] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[3] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[3] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 32(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[4] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[4] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[4] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[4] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[4] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[4] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // clear the flags - XORQ AX, AX - MOVQ 40(R13), DX - - // (A,t[0]) := t[0] + x[0]*y[5] + A - MULXQ R10, AX, BP - ADOXQ AX, R14 - - // (A,t[1]) := t[1] + x[1]*y[5] + A - ADCXQ BP, R15 - MULXQ R11, AX, BP - ADOXQ AX, R15 - - // (A,t[2]) := t[2] + x[2]*y[5] + A - ADCXQ BP, CX - MULXQ R12, AX, BP - ADOXQ AX, CX - - // (A,t[3]) := t[3] + x[3]*y[5] + A - ADCXQ BP, BX - MULXQ 24(R8), AX, BP - ADOXQ AX, BX - - // (A,t[4]) := t[4] + x[4]*y[5] + A - ADCXQ BP, SI - MULXQ 32(R8), AX, BP - ADOXQ AX, SI - - // (A,t[5]) := t[5] + x[5]*y[5] + A - ADCXQ BP, DI - MULXQ 40(R8), AX, BP - ADOXQ AX, DI - - // A += carries from ADCXQ and ADOXQ - MOVQ $0, AX - ADCXQ AX, BP - ADOXQ AX, BP - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - - // clear the flags - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, R9 - ADCXQ R14, AX - MOVQ R9, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - - // t[5] = C + A - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ BP, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R9,R8,R13,R10,R11,R12) - REDUCE(R14,R15,CX,BX,SI,DI,R9,R8,R13,R10,R11,R12) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET - -TEXT ·fromMont(SB), NOSPLIT, $0-8 - - // the algorithm is described here - // https://hackmd.io/@gnark/modular_multiplication - // when y = 1 we have: - // for i=0 to N-1 - // t[i] = x[i] - // for i=0 to N-1 - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // t[N-1] = C - MOVQ res+0(FP), DX - MOVQ 0(DX), R14 - MOVQ 8(DX), R15 - MOVQ 16(DX), CX - MOVQ 24(DX), BX - MOVQ 32(DX), SI - MOVQ 40(DX), DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - XORQ DX, DX - - // m := t[0]*q'[0] mod W - MOVQ qInv0<>(SB), DX - IMULQ R14, DX - XORQ AX, AX - - // C,_ := t[0] + m*q[0] - MULXQ q<>+0(SB), AX, BP - ADCXQ R14, AX - MOVQ BP, R14 - - // (C,t[0]) := t[1] + m*q[1] + C - ADCXQ R15, R14 - MULXQ q<>+8(SB), AX, R15 - ADOXQ AX, R14 - - // (C,t[1]) := t[2] + m*q[2] + C - ADCXQ CX, R15 - MULXQ q<>+16(SB), AX, CX - ADOXQ AX, R15 - - // (C,t[2]) := t[3] + m*q[3] + C - ADCXQ BX, CX - MULXQ q<>+24(SB), AX, BX - ADOXQ AX, CX - - // (C,t[3]) := t[4] + m*q[4] + C - ADCXQ SI, BX - MULXQ q<>+32(SB), AX, SI - ADOXQ AX, BX - - // (C,t[4]) := t[5] + m*q[5] + C - ADCXQ DI, SI - MULXQ q<>+40(SB), AX, DI - ADOXQ AX, SI - MOVQ $0, AX - ADCXQ AX, DI - ADOXQ AX, DI - - // reduce element(R14,R15,CX,BX,SI,DI) using temp registers (R8,R9,R10,R11,R12,R13) - REDUCE(R14,R15,CX,BX,SI,DI,R8,R9,R10,R11,R12,R13) - - MOVQ res+0(FP), AX - MOVQ R14, 0(AX) - MOVQ R15, 8(AX) - MOVQ CX, 16(AX) - MOVQ BX, 24(AX) - MOVQ SI, 32(AX) - MOVQ DI, 40(AX) - RET diff --git a/ecc/bw6-761/fr/element_mul_amd64.s b/ecc/bw6-761/fr/element_mul_amd64.s index b32bb9e20..3e7650e5a 100644 --- a/ecc/bw6-761/fr/element_mul_amd64.s +++ b/ecc/bw6-761/fr/element_mul_amd64.s @@ -1,4 +1,4 @@ -// +build !amd64_adx +// +build !purego // Copyright 2020 ConsenSys Software Inc. // diff --git a/ecc/bw6-761/fr/element_ops_amd64.go b/ecc/bw6-761/fr/element_ops_amd64.go index b37a251b6..e40a9caed 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.go +++ b/ecc/bw6-761/fr/element_ops_amd64.go @@ -1,3 +1,6 @@ +//go:build !purego +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -35,7 +38,70 @@ func fromMont(res *Element) func reduce(res *Element) // Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) +// +// a = a + b (mod q) +// b = a - b (mod q) +// //go:noescape func Butterfly(a, b *Element) + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for doc. + mul(z, x, x) + return z +} diff --git a/ecc/bw6-761/fr/element_ops_amd64.s b/ecc/bw6-761/fr/element_ops_amd64.s index 5c31cbc7a..7242622a4 100644 --- a/ecc/bw6-761/fr/element_ops_amd64.s +++ b/ecc/bw6-761/fr/element_ops_amd64.s @@ -1,3 +1,5 @@ +// +build !purego + // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/ecc/bw6-761/fr/element_ops_noasm.go b/ecc/bw6-761/fr/element_ops_noasm.go deleted file mode 100644 index 44897ff4e..000000000 --- a/ecc/bw6-761/fr/element_ops_noasm.go +++ /dev/null @@ -1,63 +0,0 @@ -//go:build !amd64 -// +build !amd64 - -// Copyright 2020 ConsenSys Software Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by consensys/gnark-crypto DO NOT EDIT - -package fr - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 1176283927673829444, - 14130787773971430395, - 11354866436980285261, - 15740727779991009548, - 14951814113394531041, - 33013799364667434, - } - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} diff --git a/ecc/bw6-761/fr/element_ops_purego.go b/ecc/bw6-761/fr/element_ops_purego.go new file mode 100644 index 000000000..bd2d33293 --- /dev/null +++ b/ecc/bw6-761/fr/element_ops_purego.go @@ -0,0 +1,745 @@ +//go:build !amd64 || purego +// +build !amd64 purego + +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 1176283927673829444, + 14130787773971430395, + 11354866436980285261, + 15740727779991009548, + 14951814113394531041, + 33013799364667434, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + // + // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: + // (also described in https://eprint.iacr.org/2022/1400.pdf annex) + // + // for i=0 to N-1 + // (A,t[0]) := t[0] + x[0]*y[i] + // m := t[0]*q'[0] mod W + // C,_ := t[0] + m*q[0] + // for j=1 to N-1 + // (A,t[j]) := t[j] + x[j]*y[i] + A + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // t[N-1] = C + A + // + // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit + // of the modulus is zero (and not all of the remaining bits are set). + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, y[0]) + u1, t1 = bits.Mul64(v, y[1]) + u2, t2 = bits.Mul64(v, y[2]) + u3, t3 = bits.Mul64(v, y[3]) + u4, t4 = bits.Mul64(v, y[4]) + u5, t5 = bits.Mul64(v, y[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, y[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, y[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, y[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, y[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, y[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, y[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t0, t1, t2, t3, t4, t5 uint64 + var u0, u1, u2, u3, u4, u5 uint64 + { + var c0, c1, c2 uint64 + v := x[0] + u0, t0 = bits.Mul64(v, x[0]) + u1, t1 = bits.Mul64(v, x[1]) + u2, t2 = bits.Mul64(v, x[2]) + u3, t3 = bits.Mul64(v, x[3]) + u4, t4 = bits.Mul64(v, x[4]) + u5, t5 = bits.Mul64(v, x[5]) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, 0, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[1] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[2] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[3] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[4] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + { + var c0, c1, c2 uint64 + v := x[5] + u0, c1 = bits.Mul64(v, x[0]) + t0, c0 = bits.Add64(c1, t0, 0) + u1, c1 = bits.Mul64(v, x[1]) + t1, c0 = bits.Add64(c1, t1, c0) + u2, c1 = bits.Mul64(v, x[2]) + t2, c0 = bits.Add64(c1, t2, c0) + u3, c1 = bits.Mul64(v, x[3]) + t3, c0 = bits.Add64(c1, t3, c0) + u4, c1 = bits.Mul64(v, x[4]) + t4, c0 = bits.Add64(c1, t4, c0) + u5, c1 = bits.Mul64(v, x[5]) + t5, c0 = bits.Add64(c1, t5, c0) + + c2, _ = bits.Add64(0, 0, c0) + t1, c0 = bits.Add64(u0, t1, 0) + t2, c0 = bits.Add64(u1, t2, c0) + t3, c0 = bits.Add64(u2, t3, c0) + t4, c0 = bits.Add64(u3, t4, c0) + t5, c0 = bits.Add64(u4, t5, c0) + c2, _ = bits.Add64(u5, c2, c0) + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + _, c0 = bits.Add64(t0, c1, 0) + u1, c1 = bits.Mul64(m, q1) + t0, c0 = bits.Add64(t1, c1, c0) + u2, c1 = bits.Mul64(m, q2) + t1, c0 = bits.Add64(t2, c1, c0) + u3, c1 = bits.Mul64(m, q3) + t2, c0 = bits.Add64(t3, c1, c0) + u4, c1 = bits.Mul64(m, q4) + t3, c0 = bits.Add64(t4, c1, c0) + u5, c1 = bits.Mul64(m, q5) + + t4, c0 = bits.Add64(0, c1, c0) + u5, _ = bits.Add64(u5, 0, c0) + t0, c0 = bits.Add64(u0, t0, 0) + t1, c0 = bits.Add64(u1, t1, c0) + t2, c0 = bits.Add64(u2, t2, c0) + t3, c0 = bits.Add64(u3, t3, c0) + t4, c0 = bits.Add64(u4, t4, c0) + c2, _ = bits.Add64(c2, 0, c0) + t4, c0 = bits.Add64(t5, t4, 0) + t5, _ = bits.Add64(u5, c2, c0) + + } + z[0] = t0 + z[1] = t1 + z[2] = t2 + z[3] = t3 + z[4] = t4 + z[5] = t5 + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], b = bits.Sub64(z[3], q3, b) + z[4], b = bits.Sub64(z[4], q4, b) + z[5], _ = bits.Sub64(z[5], q5, b) + } + return z +} diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 5f0d749fe..21c95400f 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -23,7 +23,7 @@ import ( "math/big" "math/bits" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" "testing" @@ -182,17 +182,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -642,7 +635,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -757,7 +750,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -770,13 +763,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -805,17 +798,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -866,7 +859,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -879,13 +872,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -914,17 +907,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -975,7 +968,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -988,7 +981,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -1002,7 +995,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1042,11 +1035,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -1059,7 +1052,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1111,7 +1104,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1124,14 +1117,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1160,18 +1153,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1222,7 +1215,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1235,13 +1228,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1270,17 +1263,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1325,7 +1318,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1346,14 +1339,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1397,7 +1390,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1418,14 +1411,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1469,7 +1462,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1490,14 +1483,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1541,7 +1534,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1562,14 +1555,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1613,7 +1606,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1634,14 +1627,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -2029,7 +2022,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2162,17 +2155,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2255,7 +2248,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } @@ -2323,7 +2316,7 @@ func (z *Element) matchVeryBigInt(aHi uint64, aInt *big.Int) error { return field.BigIntMatchUint64Slice(&aIntMod, slice) } -//TODO: Phase out in favor of property based testing +// TODO: Phase out in favor of property based testing func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { if err := z.matchVeryBigInt(aHi, aInt); err != nil { @@ -2367,7 +2360,7 @@ func TestElementInversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord4, inversionCorrectionFactorWord5, } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -2409,7 +2402,7 @@ func TestElementInversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac Element fac.setBigInt(&i) // back to montgomery @@ -2558,7 +2551,7 @@ func TestElement0Inverse(t *testing.T) { } } -//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +// TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen func TestUpdateFactorSubtraction(t *testing.T) { for i := 0; i < 1000; i++ { @@ -2771,11 +2764,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -2812,7 +2805,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -2830,7 +2823,7 @@ func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -2841,7 +2834,7 @@ func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, result func approximateRef(x *Element) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/ecc/bw6-761/fr/fri/fri.go b/ecc/bw6-761/fr/fri/fri.go index aea48b57b..590ecde99 100644 --- a/ecc/bw6-761/fr/fri/fri.go +++ b/ecc/bw6-761/fr/fri/fri.go @@ -121,7 +121,6 @@ type Round struct { // a function is d-close to a low degree polynomial. // // It is composed of a series of Interactions, emulated with Fiat Shamir, -// type ProofOfProximity struct { // ID unique ID attached to the proof of proximity. It's needed for diff --git a/ecc/bw6-761/fr/gkr/gkr.go b/ecc/bw6-761/fr/gkr/gkr.go new file mode 100644 index 000000000..2dd387df0 --- /dev/null +++ b/ecc/bw6-761/fr/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...fr.Element) fr.Element + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]fr.Element + claimedEvaluations []fr.Element + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a fr.Element) fr.Element { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]fr.Element) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation fr.Element + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]fr.Element, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]fr.Element // x in the paper + claimedEvaluations []fr.Element // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff fr.Element) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]fr.Element, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]fr.Element + + if parallel { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]fr.Element{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *fr.Element, gateInput []fr.Element, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf fr.Element + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element fr.Element) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []fr.Element) interface{} { + + //defer the proof, return list of claims + evaluations := make([]fr.Element, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]fr.Element, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []fr.Element, evaluation fr.Element) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]fr.Element, error) { + res := make([]fr.Element, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []fr.Element{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]fr.Element) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []fr.Element + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]fr.Element) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...fr.Element) fr.Element { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// topologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func topologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := topologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]fr.Element, numEvaluations) + ins := make([]fr.Element, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/ecc/bw6-761/fr/gkr/gkr_test.go b/ecc/bw6-761/fr/gkr/gkr_test.go new file mode 100644 index 000000000..a61b77854 --- /dev/null +++ b/ecc/bw6-761/fr/gkr/gkr_test.go @@ -0,0 +1,722 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/sumcheck" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "os" + "path/filepath" + "reflect" + "strconv" + "testing" +) + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []fr.Element{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []fr.Element{four, three}, []fr.Element{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []fr.Element{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []fr.Element{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []fr.Element{one, one}, []fr.Element{one, two}) +} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []fr.Element{one, one}, []fr.Element{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]fr.Element) { + return func(t *testing.T, inputAssignments ...[]fr.Element) { + testMimc(t, numRounds, inputAssignments...) + } +} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + }} + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []fr.Element{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []fr.Element{three}, five) + manager.add(wire, []fr.Element{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six fr.Element + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]fr.Element)) { + fullAssignments := make([][]fr.Element, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]fr.Element, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]fr.Element) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]fr.Element) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]fr.Element) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []fr.Element) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "../../../../internal/generator/gkr/test_vectors" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := topologicalSort(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := topologicalSort(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]Circuit) + +func getCircuit(path string) (Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit Circuit) { + circuit = make(Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]Gate + +func init() { + gates = make(map[string]Gate) + gates["identity"] = IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark fr.Element +} + +func (m mimcCipherGate) Evaluate(input ...fr.Element) (res fr.Element) { + var sum fr.Element + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (Proof, error) { + proof := make(Proof, len(printable)) + for i := range printable { + finalEvalProof := []fr.Element(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]fr.Element, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := test_vector_utils.SetElement(&finalEvalProof[k], finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit Circuit + Hash *test_vector_utils.ElementMap + Proof Proof + FullAssignment WireAssignment + InOutAssignment WireAssignment +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(WireAssignment) + inOutAssignment := make(WireAssignment) + + sorted := topologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []fr.Element + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + for _, w := range sorted { + if w.IsOutput() { + + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...fr.Element) (result fr.Element) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...fr.Element) fr.Element { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/ecc/bw6-761/fr/kzg/kzg.go b/ecc/bw6-761/fr/kzg/kzg.go index 58c3c23fb..5650b88b3 100644 --- a/ecc/bw6-761/fr/kzg/kzg.go +++ b/ecc/bw6-761/fr/kzg/kzg.go @@ -84,9 +84,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := bw6761.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -125,7 +122,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res bw6761.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -395,7 +392,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -478,7 +475,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/ecc/bw6-761/fr/mimc/decompose.go b/ecc/bw6-761/fr/mimc/decompose.go new file mode 100644 index 000000000..06761e28f --- /dev/null +++ b/ecc/bw6-761/fr/mimc/decompose.go @@ -0,0 +1,46 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/ecc/bw6-761/fr/mimc/decompose_test.go b/ecc/bw6-761/fr/mimc/decompose_test.go new file mode 100644 index 000000000..031367811 --- /dev/null +++ b/ecc/bw6-761/fr/mimc/decompose_test.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package mimc + +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/ecc/bw6-761/fr/mimc/mimc.go b/ecc/bw6-761/fr/mimc/mimc.go index 5c8bc52e7..3b69ba02b 100644 --- a/ecc/bw6-761/fr/mimc/mimc.go +++ b/ecc/bw6-761/fr/mimc/mimc.go @@ -17,6 +17,7 @@ package mimc import ( + "errors" "hash" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -91,44 +92,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n%BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i : i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { + // Write guarantees len(data) % BlockSize == 0 - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? if len(d.data) == 0 { - d.data = make([]byte, 32) + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i += BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i : i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/ecc/bw6-761/fr/pedersen/pedersen.go b/ecc/bw6-761/fr/pedersen/pedersen.go new file mode 100644 index 000000000..4af6559c9 --- /dev/null +++ b/ecc/bw6-761/fr/pedersen/pedersen.go @@ -0,0 +1,113 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g bw6761.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg bw6761.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []bw6761.G1Affine + basisExpSigma []bw6761.G1Affine +} + +func randomOnG2() (bw6761.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6761.G2Affine{}, err + } + return bw6761.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []bw6761.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]bw6761.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment bw6761.G1Affine, knowledgeProof bw6761.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment bw6761.G1Affine, knowledgeProof bw6761.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := bw6761.Pair([]bw6761.G1Affine{commitment, knowledgeProof}, []bw6761.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/ecc/bw6-761/fr/pedersen/pedersen_test.go b/ecc/bw6-761/fr/pedersen/pedersen_test.go new file mode 100644 index 000000000..1747d3a0b --- /dev/null +++ b/ecc/bw6-761/fr/pedersen/pedersen_test.go @@ -0,0 +1,90 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package pedersen + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-761" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() (bw6761.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return bw6761.G1Affine{}, err + } + return bw6761.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]bw6761.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok bw6761.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/ecc/bw6-761/fr/plookup/vector.go b/ecc/bw6-761/fr/plookup/vector.go index 0fffc1d1f..ba3203090 100644 --- a/ecc/bw6-761/fr/plookup/vector.go +++ b/ecc/bw6-761/fr/plookup/vector.go @@ -125,7 +125,8 @@ func evaluateAccumulationPolynomial(lf, lt, lh1, lh2 []fr.Element, beta, gamma f // evaluateNumBitReversed computes the evaluation (shifted, bit reversed) of h where // h = (x-1)*z*(1+\beta)*(\gamma+f)*(\gamma(1+\beta) + t+ \beta*t(gX)) - -// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) +// +// (x-1)*z(gX)*(\gamma(1+\beta) + h_{1} + \beta*h_{1}(gX))*(\gamma(1+\beta) + h_{2} + \beta*h_{2}(gX) ) // // * cz, ch1, ch2, ct, cf are the polynomials z, h1, h2, t, f in canonical basis // * _lz, _lh1, _lh2, _lt, _lf are the polynomials z, h1, h2, t, f in shifted Lagrange basis (domainBig) @@ -370,7 +371,6 @@ func computeQuotientCanonical(alpha fr.Element, lh, lh0, lhn, lh1h2 []fr.Element // before generating a lookup proof), the commitment needs to be done on the // table sorted. Otherwise the commitment in proof.t will not be the same as // the public commitment: it will contain the same values, but permuted. -// func ProveLookupVector(srs *kzg.SRS, f, t Table) (ProofLookupVector, error) { // res diff --git a/ecc/bw6-761/fr/polynomial/multilin.go b/ecc/bw6-761/fr/polynomial/multilin.go index c20d53b13..393f2c750 100644 --- a/ecc/bw6-761/fr/polynomial/multilin.go +++ b/ecc/bw6-761/fr/polynomial/multilin.go @@ -18,6 +18,7 @@ package polynomial import ( "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "math/bits" ) // MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial @@ -46,44 +47,62 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) - } - - // Resize the destination table - *m = (*m)[:size] // Add elementwise for i := 0; i < size; i++ { @@ -93,15 +112,17 @@ func (m *MultiLin) Add(left, right MultiLin) { // EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) // where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates -// _________________ -// | | | -// | 0 | 1 | -// |_______|_______| -// y | | | -// | 1 | 0 | -// |_______|_______| // -// x +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values func EvalEq(q, h []fr.Element) fr.Element { @@ -128,10 +149,7 @@ func (m *MultiLin) Eq(q []fr.Element) { n := len(q) if len(*m) != 1< 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -171,12 +175,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -188,13 +193,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/ecc/bw6-761/fr/polynomial/polynomial_test.go b/ecc/bw6-761/fr/polynomial/polynomial_test.go index 8a00a1daa..c293e9a26 100644 --- a/ecc/bw6-761/fr/polynomial/polynomial_test.go +++ b/ecc/bw6-761/fr/polynomial/polynomial_test.go @@ -17,10 +17,10 @@ package polynomial import ( + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/stretchr/testify/assert" "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" ) func TestPolynomialEval(t *testing.T) { @@ -206,3 +206,13 @@ func TestPolynomialAdd(t *testing.T) { t.Fatal("side effect, _f2 should not have been modified") } } + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/bw6-761/fr/polynomial/pool.go b/ecc/bw6-761/fr/polynomial/pool.go index 1f57a87ce..8630c85b9 100644 --- a/ecc/bw6-761/fr/polynomial/pool.go +++ b/ecc/bw6-761/fr/polynomial/pool.go @@ -17,114 +17,187 @@ package polynomial import ( + "encoding/json" "fmt" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" ) // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse } - return unsafe.Pointer(&m[0]) + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/ecc/bw6-761/fr/sumcheck/sumcheck.go b/ecc/bw6-761/fr/sumcheck/sumcheck.go new file mode 100644 index 000000000..7c0aa092e --- /dev/null +++ b/ecc/bw6-761/fr/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a fr.Element) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(fr.Element) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []fr.Element) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a fr.Element) fr.Element // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []fr.Element, remainingChallengeNames *[]string) (fr.Element, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return fr.Element{}, err + } + } + var res fr.Element + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff fr.Element + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]fr.Element, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff fr.Element + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []fr.Element{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]fr.Element, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/ecc/bw6-761/fr/sumcheck/sumcheck_test.go b/ecc/bw6-761/fr/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..17152970e --- /dev/null +++ b/ecc/bw6-761/fr/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []fr.Element) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []fr.Element{sum} +} + +func (c singleMultilinClaim) Combine(fr.Element) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r fr.Element) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum fr.Element +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []fr.Element, combinationCoeff fr.Element, purportedValue fr.Element, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs fr.Element) fr.Element { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go b/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..00b75ce22 --- /dev/null +++ b/ecc/bw6-761/fr/test_vector_utils/test_vector_utils.go @@ -0,0 +1,429 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" + "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/polynomial" + "hash" + + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 fr.Element + key2 fr.Element + key2Present bool + value fr.Element + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := SetElement(&entry.value, v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := SetElement(&entry.key2, key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := SetElement(&entry.key1, key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state fr.Element + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x fr.Element + for i := 0; i < len(p); i += fr.Bytes { + x.SetBytes(p[i:min(len(p), i+fr.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return fr.Bytes +} + +func (m *MapHash) BlockSize() int { + return fr.Bytes +} + +func (m *MapHash) write(x fr.Element) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (fr.Element, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return fr.Element{}, fmt.Errorf(sb.String()) +} + +func (m *ElementMap) FindPair(x *fr.Element, y *fr.Element) (fr.Element, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *fr.Element { + var res fr.Element + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/fr.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/fr.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res fr.Element + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return fr.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return fr.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []fr.Element + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return fr.Bytes +} + +func (h *ListHash) BlockSize() int { + return fr.Bytes +} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} + +func SliceToElementSlice[T any](slice []T) ([]fr.Element, error) { + elementSlice := make([]fr.Element, len(slice)) + for i, v := range slice { + if _, err := SetElement(&elementSlice[i], v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []fr.Element, b []fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]fr.Element, b [][]fr.Element) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *fr.Element) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(fr.Element) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/ecc/bw6-761/g1.go b/ecc/bw6-761/g1.go index 0263285ab..7457b4fdb 100644 --- a/ecc/bw6-761/g1.go +++ b/ecc/bw6-761/g1.go @@ -17,13 +17,12 @@ package bw6761 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G1Affine point in affine coordinates @@ -36,7 +35,7 @@ type G1Jac struct { X, Y, Z fp.Element } -// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g1JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -55,6 +54,13 @@ func (p *G1Affine) Set(a *G1Affine) *G1Affine { return p } +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { var _p G1Jac @@ -378,6 +384,7 @@ func (p *G1Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG1Affine, 1] is the kernel // of (u,v)->u+lambdaG1Affinev mod r. Expressing r, lambdaG1Affine as // polynomials in x, a short vector of this Zmodule is @@ -481,8 +488,8 @@ func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -521,6 +528,7 @@ func (p *G1Affine) ClearCofactor(a *G1Affine) *G1Affine { // ClearCofactor maps a point in E(Fp) to E(Fp)[r] func (p *G1Jac) ClearCofactor(a *G1Jac) *G1Jac { + // https://eprint.iacr.org/2020/351.pdf var points [4]G1Jac points[0].Set(a) @@ -623,15 +631,15 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -643,11 +651,7 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -672,6 +676,8 @@ func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { var U, V, W, S, XX, M fp.Element @@ -989,95 +995,72 @@ func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G1Jac, (1 << (c - 1))) - baseTable[0].Set(&g1Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G1Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffineG1(baseTable) toReturn := make([]G1Jac, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G1Jac for i := start; i < end; i++ { p.Set(&g1Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit>>1)-1]) } else { // sub - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit>>1] t.Neg(&t) p.AddMixed(&t) } @@ -1091,3 +1074,54 @@ func BatchScalarMultiplicationG1(base *G1Affine, scalars []fr.Element) []G1Affin toReturnAff := BatchJacobianToAffineG1(toReturn) return toReturnAff } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG1Affine[TP pG1Affine, TPP ppG1Affine, TC cG1Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G1Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-761/g1_test.go b/ecc/bw6-761/g1_test.go index 6ace718ac..5a3edef3d 100644 --- a/ecc/bw6-761/g1_test.go +++ b/ecc/bw6-761/g1_test.go @@ -19,6 +19,7 @@ package bw6761 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" @@ -338,7 +339,7 @@ func TestG1AffineOps(t *testing.T) { r := fr.Modulus() var g G1Jac - g.mulGLV(&g1Gen, r) + g.ScalarMultiplication(&g1Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G1Jac @@ -458,8 +459,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG1(&g1GenAff, sampleScalars[:]) @@ -472,7 +472,7 @@ func TestG1AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G1Jac var expected G1Affine var b big.Int - expectedJac.mulGLV(&g1Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g1Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -499,6 +499,33 @@ func BenchmarkG1JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG1Affine(b *testing.B) { + + var P, R pG1AffineC16 + var RR ppG1AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG1(P[:]) + fillBenchBasesG1(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG1Affine[pG1AffineC16, ppG1AffineC16, cG1AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -511,8 +538,7 @@ func BenchmarkG1AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-761/g2.go b/ecc/bw6-761/g2.go index a56f3dfeb..9687efef3 100644 --- a/ecc/bw6-761/g2.go +++ b/ecc/bw6-761/g2.go @@ -17,13 +17,12 @@ package bw6761 import ( - "math/big" - "runtime" - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" "github.com/consensys/gnark-crypto/internal/parallel" + "math/big" + "runtime" ) // G2Affine point in affine coordinates @@ -36,7 +35,7 @@ type G2Jac struct { X, Y, Z fp.Element } -// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +// g2JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) type g2JacExtended struct { X, Y, ZZ, ZZZ fp.Element } @@ -50,6 +49,13 @@ func (p *G2Affine) Set(a *G2Affine) *G2Affine { return p } +// setInfinity sets p to O +func (p *G2Affine) setInfinity() *G2Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *G2Affine) ScalarMultiplication(a *G2Affine, s *big.Int) *G2Affine { var _p G2Jac @@ -365,6 +371,7 @@ func (p *G2Jac) IsOnCurve() bool { } // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + // Z[r,0]+Z[-lambdaG2Affine, 1] is the kernel // of (u,v)->u+lambdaG2Affinev mod r. Expressing r, lambdaG2Affine as // polynomials in x, a short vector of this Zmodule is @@ -468,8 +475,8 @@ func (p *G2Jac) mulGLV(a *G2Jac, s *big.Int) *G2Jac { // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -508,7 +515,6 @@ func (p *G2Affine) ClearCofactor(a *G2Affine) *G2Affine { // ClearCofactor maps a point in curve to r-torsion func (p *G2Jac) ClearCofactor(a *G2Jac) *G2Jac { - var points [4]G2Jac points[0].Set(a) points[1].ScalarMultiplication(a, &xGen) @@ -610,15 +616,15 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 fp.Element + var A, B, U1, U2, S1, S2 fp.Element // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) if A.IsZero() { if B.IsZero() { @@ -630,11 +636,7 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V fp.Element - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + var P, R, PP, PPP, Q, V fp.Element P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -659,6 +661,8 @@ func (p *g2JacExtended) add(q *g2JacExtended) *g2JacExtended { // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *g2JacExtended) double(q *g2JacExtended) *g2JacExtended { var U, V, W, S, XX, M fp.Element @@ -852,93 +856,70 @@ func (p *g2JacExtended) doubleMixed(q *G2Affine) *g2JacExtended { // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affine { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c - 1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - cost += nbPoints * ((fr.Limbs * 64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c - 1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64)%c != 0 { - nbChunks++ + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]G2Jac, (1 << (c - 1))) - baseTable[0].Set(&g2Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]G2Jac, (1 << (maxC - 1))) + baseTable[0].FromAffine(base) for i := 1; i < len(baseTable); i++ { baseTable[i] = baseTable[i-1] baseTable[i].AddMixed(base) } - - pScalars, _ := partitionScalars(scalars, c, false, runtime.NumCPU()) - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := 0; chunk < nbChunks; chunk++ { - jc := uint64(uint64(chunk) * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = (64%c) != 0 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } toReturn := make([]G2Affine, len(scalars)) + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute(len(pScalars), func(start, end int) { + parallel.Execute(len(scalars), func(start, end int) { var p G2Jac for i := start; i < end; i++ { p.Set(&g2Infinity) for chunk := nbChunks - 1; chunk >= 0; chunk-- { - s := selectors[chunk] if chunk != nbChunks-1 { for j := uint64(0); j < c; j++ { p.DoubleAssign() } } + offset := chunk * len(scalars) + digit := digits[i+offset] - bits := (pScalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { + if digit&1 == 0 { // add - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit>>1)-1]) } else { // sub - t := baseTable[bits & ^msbWindow] + t := baseTable[digit>>1] t.Neg(&t) p.AddAssign(&t) } @@ -951,3 +932,54 @@ func BatchScalarMultiplicationG2(base *G2Affine, scalars []fr.Element) []G2Affin }) return toReturn } + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAddG2Affine[TP pG2Affine, TPP ppG2Affine, TC cG2Affine](R *TPP, P *TP, batchSize int) { + var lambda, lambdain TC + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator fp.Element + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d fp.Element + var rr G2Affine + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/ecc/bw6-761/g2_test.go b/ecc/bw6-761/g2_test.go index 9630dbf17..8c3f321e6 100644 --- a/ecc/bw6-761/g2_test.go +++ b/ecc/bw6-761/g2_test.go @@ -19,6 +19,7 @@ package bw6761 import ( "fmt" "math/big" + "math/rand" "testing" "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" @@ -325,7 +326,7 @@ func TestG2AffineOps(t *testing.T) { r := fr.Modulus() var g G2Jac - g.mulGLV(&g2Gen, r) + g.ScalarMultiplication(&g2Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg G2Jac @@ -445,8 +446,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplicationG2(&g2GenAff, sampleScalars[:]) @@ -459,7 +459,7 @@ func TestG2AffineBatchScalarMultiplication(t *testing.T) { var expectedJac G2Jac var expected G2Affine var b big.Int - expectedJac.mulGLV(&g2Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&g2Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -486,6 +486,33 @@ func BenchmarkG2JacIsInSubGroup(b *testing.B) { } +func BenchmarkBatchAddG2Affine(b *testing.B) { + + var P, R pG2AffineC16 + var RR ppG2AffineC16 + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBasesG2(P[:]) + fillBenchBasesG2(R[:]) + + for i := 0; i < len(ridx); i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAddG2Affine[pG2AffineC16, ppG2AffineC16, cG2AffineC16](&RR, &P, len(P)) + } +} + func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -498,8 +525,7 @@ func BenchmarkG2AffineBatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/ecc/bw6-761/hash_to_g1.go b/ecc/bw6-761/hash_to_g1.go index 716a516b9..a26f10d83 100644 --- a/ecc/bw6-761/hash_to_g1.go +++ b/ecc/bw6-761/hash_to_g1.go @@ -17,7 +17,6 @@ package bw6761 import ( - "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" "math/big" @@ -219,35 +218,14 @@ func g1EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f z.Set(&dst) } -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits-1)/8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} - // g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func g1Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -265,11 +243,11 @@ func MapToG1(u fp.Element) G1Affine { // EncodeToG1 hashes a message to a point on the G1 curve using the SSWU map. // It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG1(msg, dst []byte) (G1Affine, error) { var res G1Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -285,9 +263,9 @@ func EncodeToG1(msg, dst []byte) (G1Affine, error) { // HashToG1 hashes a message to a point on the G1 curve using the SSWU map. // Slower than EncodeToG1, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG1(msg, dst []byte) (G1Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G1Affine{}, err } diff --git a/ecc/bw6-761/hash_to_g1_test.go b/ecc/bw6-761/hash_to_g1_test.go index e8f3ecbed..a268f8b0e 100644 --- a/ecc/bw6-761/hash_to_g1_test.go +++ b/ecc/bw6-761/hash_to_g1_test.go @@ -62,7 +62,7 @@ func TestG1SqrtRatio(t *testing.T) { func TestHashToFpG1(t *testing.T) { for _, c := range encodeToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG1Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG1(t *testing.T) { } for _, c := range hashToG1Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG1Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG1(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE1Prime(p G1Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE1Prime(p G1Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g1CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-761/hash_to_g2.go b/ecc/bw6-761/hash_to_g2.go index bacb36d66..f0bc4678e 100644 --- a/ecc/bw6-761/hash_to_g2.go +++ b/ecc/bw6-761/hash_to_g2.go @@ -400,8 +400,8 @@ func g2EvalPolynomial(z *fp.Element, monic bool, coefficients []fp.Element, x *f // The sign of an element is not obviously related to that of its Montgomery form func g2Sgn0(z *fp.Element) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + // m == 1 return nonMont[0] % 2 @@ -419,11 +419,11 @@ func MapToG2(u fp.Element) G2Affine { // EncodeToG2 hashes a message to a point on the G2 curve using the SSWU map. // It is faster than HashToG2, but the result is not uniformly distributed. Unsuitable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func EncodeToG2(msg, dst []byte) (G2Affine, error) { var res G2Affine - u, err := hashToFp(msg, dst, 1) + u, err := fp.Hash(msg, dst, 1) if err != nil { return res, err } @@ -439,9 +439,9 @@ func EncodeToG2(msg, dst []byte) (G2Affine, error) { // HashToG2 hashes a message to a point on the G2 curve using the SSWU map. // Slower than EncodeToG2, but usable as a random oracle. // dst stands for "domain separation tag", a string unique to the construction using the hash function -//https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashToG2(msg, dst []byte) (G2Affine, error) { - u, err := hashToFp(msg, dst, 2*1) + u, err := fp.Hash(msg, dst, 2*1) if err != nil { return G2Affine{}, err } diff --git a/ecc/bw6-761/hash_to_g2_test.go b/ecc/bw6-761/hash_to_g2_test.go index 05e5008d4..246016b4c 100644 --- a/ecc/bw6-761/hash_to_g2_test.go +++ b/ecc/bw6-761/hash_to_g2_test.go @@ -62,7 +62,7 @@ func TestG2SqrtRatio(t *testing.T) { func TestHashToFpG2(t *testing.T) { for _, c := range encodeToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeToG2Vector.dst, 1) + elems, err := fp.Hash([]byte(c.msg), encodeToG2Vector.dst, 1) if err != nil { t.Error(err) } @@ -70,7 +70,7 @@ func TestHashToFpG2(t *testing.T) { } for _, c := range hashToG2Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashToG2Vector.dst, 2*1) + elems, err := fp.Hash([]byte(c.msg), hashToG2Vector.dst, 2*1) if err != nil { t.Error(err) } @@ -220,7 +220,7 @@ func BenchmarkHashToG2(b *testing.B) { } } -//TODO: Crude. Do something clever in Jacobian +// TODO: Crude. Do something clever in Jacobian func isOnE2Prime(p G2Affine) bool { var A, B fp.Element @@ -247,7 +247,7 @@ func isOnE2Prime(p G2Affine) bool { return LHS.Equal(&RHS) } -//Only works on simple extensions (two-story towers) +// Only works on simple extensions (two-story towers) func g2CoordSetString(z *fp.Element, s string) { z.SetString(s) } diff --git a/ecc/bw6-761/internal/fptower/e3.go b/ecc/bw6-761/internal/fptower/e3.go index c8e1b9a76..5f2a5d537 100644 --- a/ecc/bw6-761/internal/fptower/e3.go +++ b/ecc/bw6-761/internal/fptower/e3.go @@ -83,6 +83,10 @@ func (z *E3) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() && z.A2.IsZero() } +func (z *E3) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() && z.A2.IsZero() +} + // Neg negates the E3 number func (z *E3) Neg(x *E3) *E3 { z.A0.Neg(&x.A0) @@ -91,22 +95,6 @@ func (z *E3) Neg(x *E3) *E3 { return z } -// ToMont converts to Mont form -func (z *E3) ToMont() *E3 { - z.A0.ToMont() - z.A1.ToMont() - z.A2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E3) FromMont() *E3 { - z.A0.FromMont() - z.A1.FromMont() - z.A2.FromMont() - return z -} - // Add adds two elements of E3 func (z *E3) Add(x, y *E3) *E3 { z.A0.Add(&x.A0, &y.A0) diff --git a/ecc/bw6-761/internal/fptower/e6.go b/ecc/bw6-761/internal/fptower/e6.go index a23797234..12167e26c 100644 --- a/ecc/bw6-761/internal/fptower/e6.go +++ b/ecc/bw6-761/internal/fptower/e6.go @@ -67,20 +67,6 @@ func (z *E6) SetOne() *E6 { return z } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - return z -} - // Add set z=x+y in E6 and return z func (z *E6) Add(x, y *E6) *E6 { z.B0.Add(&x.B0, &y.B0) @@ -118,6 +104,10 @@ func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() } +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() +} + // Mul set z=x*y in E6 and return z func (z *E6) Mul(x, y *E6) *E6 { var a, b, c E3 @@ -225,9 +215,12 @@ func (z *E6) CyclotomicSquareCompressed(x *E6) *E6 { // DecompressKarabina Karabina's cyclotomic square result // if g3 != 0 -// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// +// g4 = (E * g5^2 + 3 * g1^2 - 2 * g2)/4g3 +// // if g3 == 0 -// g4 = 2g1g5/g2 +// +// g4 = 2g1g5/g2 // // if g3=g2=0 then g4=g5=g1=0 and g0=1 (x=1) // Theorem 3.1 is well-defined for all x in Gϕₙ\{1} @@ -252,7 +245,7 @@ func (z *E6) DecompressKarabina(x *E6) *E6 { t[1].Sub(&t[0], &x.B0.A2). Double(&t[1]). Add(&t[1], &t[0]) - // t0 = E * g5^2 + t1 + // t0 = E * g5^2 + t1 t[2].Square(&x.B1.A2) t[0].MulByNonResidue(&t[2]). Add(&t[0], &t[1]) @@ -525,8 +518,8 @@ func (z *E6) ExpGLV(x E6, k *big.Int) *E6 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { diff --git a/ecc/bw6-761/marshal.go b/ecc/bw6-761/marshal.go index d70849b92..4def5de67 100644 --- a/ecc/bw6-761/marshal.go +++ b/ecc/bw6-761/marshal.go @@ -100,7 +100,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -108,7 +108,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -126,7 +126,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -145,7 +147,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -221,7 +225,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -276,7 +284,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -643,9 +655,6 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -654,20 +663,7 @@ func (p *G1Affine) Bytes() (res [SizeOfG1AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -686,41 +682,12 @@ func (p *G1Affine) RawBytes() (res [SizeOfG1AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - binary.BigEndian.PutUint64(res[136:144], tmp[6]) - binary.BigEndian.PutUint64(res[128:136], tmp[7]) - binary.BigEndian.PutUint64(res[120:128], tmp[8]) - binary.BigEndian.PutUint64(res[112:120], tmp[9]) - binary.BigEndian.PutUint64(res[104:112], tmp[10]) - binary.BigEndian.PutUint64(res[96:104], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -771,8 +738,12 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -792,7 +763,9 @@ func (p *G1Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -866,7 +839,7 @@ func (p *G1Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -875,7 +848,7 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -884,12 +857,14 @@ func (p *G1Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } // SizeOfG2AffineCompressed represents the size in bytes that a G2Affine need in binary form, compressed @@ -927,9 +902,6 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y // if p.Y ">" -p.Y @@ -938,20 +910,7 @@ func (p *G2Affine) Bytes() (res [SizeOfG2AffineCompressed]byte) { } // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= msbMask @@ -970,41 +929,12 @@ func (p *G2Affine) RawBytes() (res [SizeOfG2AffineUncompressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element - // not compressed // we store the Y coordinate - tmp = p.Y - tmp.FromMont() - binary.BigEndian.PutUint64(res[184:192], tmp[0]) - binary.BigEndian.PutUint64(res[176:184], tmp[1]) - binary.BigEndian.PutUint64(res[168:176], tmp[2]) - binary.BigEndian.PutUint64(res[160:168], tmp[3]) - binary.BigEndian.PutUint64(res[152:160], tmp[4]) - binary.BigEndian.PutUint64(res[144:152], tmp[5]) - binary.BigEndian.PutUint64(res[136:144], tmp[6]) - binary.BigEndian.PutUint64(res[128:136], tmp[7]) - binary.BigEndian.PutUint64(res[120:128], tmp[8]) - binary.BigEndian.PutUint64(res[112:120], tmp[9]) - binary.BigEndian.PutUint64(res[104:112], tmp[10]) - binary.BigEndian.PutUint64(res[96:104], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[96:96+fp.Bytes]), p.Y) // we store X and mask the most significant word with our metadata mask - tmp = p.X - tmp.FromMont() - binary.BigEndian.PutUint64(res[88:96], tmp[0]) - binary.BigEndian.PutUint64(res[80:88], tmp[1]) - binary.BigEndian.PutUint64(res[72:80], tmp[2]) - binary.BigEndian.PutUint64(res[64:72], tmp[3]) - binary.BigEndian.PutUint64(res[56:64], tmp[4]) - binary.BigEndian.PutUint64(res[48:56], tmp[5]) - binary.BigEndian.PutUint64(res[40:48], tmp[6]) - binary.BigEndian.PutUint64(res[32:40], tmp[7]) - binary.BigEndian.PutUint64(res[24:32], tmp[8]) - binary.BigEndian.PutUint64(res[16:24], tmp[9]) - binary.BigEndian.PutUint64(res[8:16], tmp[10]) - binary.BigEndian.PutUint64(res[0:8], tmp[11]) + fp.BigEndian.PutElement((*[fp.Bytes]byte)(res[0:0+fp.Bytes]), p.X) res[0] |= mUncompressed @@ -1055,8 +985,12 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { // uncompressed point if mData == mUncompressed { // read X and Y coordinates - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes : fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes : fp.Bytes*2]); err != nil { + return 0, err + } // subgroup check if subGroupCheck && !p.IsInSubGroup() { @@ -1076,7 +1010,9 @@ func (p *G2Affine) setBytes(buf []byte, subGroupCheck bool) (int, error) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } var YSquared, Y fp.Element @@ -1150,7 +1086,7 @@ func (p *G2Affine) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -1159,7 +1095,7 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -1168,10 +1104,12 @@ func (p *G2Affine) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { bufX[0] &= ^mMask // read X coordinate - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) // recomputing Y will be done asynchronously - return + return isInfinity, nil } diff --git a/ecc/bw6-761/multiexp.go b/ecc/bw6-761/multiexp.go index e9cef54ee..3f39f3c6e 100644 --- a/ecc/bw6-761/multiexp.go +++ b/ecc/bw6-761/multiexp.go @@ -25,143 +25,6 @@ import ( "runtime" ) -// selector stores the index, mask and shifts needed to select bits from a scalar -// it is used during the multiExp algorithm or the batch scalar multiplication -type selector struct { - index uint64 // index in the multi-word scalar to select bits from - mask uint64 // mask (c-bit wide) - shift uint64 // shift needed to get our bits on low positions - - multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) - maskHigh uint64 // same than mask, for index+1 - shiftHigh uint64 // same than shift, for index+1 -} - -// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits -// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract -// 2^{c} to the current digit, making it negative. -// negative digits can be processed in a later step as adding -G into the bucket instead of G -// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - - // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs*64)%c != 0 { - nbChunks++ - } - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) // msb of the c-bit window - max := int(1 << (c - 1)) // max value we want for our digits - cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - - // compute offset and word selector / shift to select the right bits of our windows - selectors := make([]selector, nbChunks) - for chunk := uint64(0); chunk < nbChunks; chunk++ { - jc := uint64(chunk * c) - d := selector{} - d.index = jc / 64 - d.shift = jc - (d.index * 64) - d.mask = mask << d.shift - d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) - - parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 - for i := start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { - // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } - } - - // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { - s := selectors[chunk] - - // init with carry if any - digit := carry - carry = 0 - - // digit = value of the c-bit window - digit += int((scalar[s.index] & s.mask) >> s.shift) - - if s.multiWordSelect { - // we are selecting bits over 2 words - digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh - } - - // if digit is zero, no impact on result - if digit == 0 { - continue - } - - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract - // 2^{c} to the current digit, making it negative. - if digit >= max { - digit -= (1 << c) - carry = 1 - } - - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow - } - - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) - } - - } - } - - chSmallValues <- smallValues - - }, nbTasks) - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues += o - } - return toReturn, smallValues -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -221,7 +84,7 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 8, 10, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -229,85 +92,128 @@ func (p *G1Jac) MultiExp(points []G1Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G1Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG1(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG1(p *G1Jac, c uint64, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG1Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G1Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG1Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) } - msmInnerG1Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG1(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG1(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g1JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG1Jac(p *G1Jac, c int, points []G1Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG1Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG1 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG1(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g1JacExtended, c uint64, points []G1Affine, digits []uint16) { switch c { + case 2: + return processChunkG1Jacobian[bucketg1JacExtendedC2] + case 3: + return processChunkG1Jacobian[bucketg1JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC5] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG1Jacobian[bucketg1JacExtendedC8] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC10] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC10, bucketG1AffineC10, bitSetC10, pG1AffineC10, ppG1AffineC10, qG1AffineC10, cG1AffineC10] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG1Jacobian[bucketg1JacExtendedC16] + } + return processChunkG1BatchAffine[bucketg1JacExtendedC16, bucketG1AffineC16, bitSetC16, pG1AffineC16, ppG1AffineC16, qG1AffineC16, cG1AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG1Jacobian[bucketg1JacExtendedC16] } } @@ -327,257 +233,6 @@ func msmReduceChunkG1Affine(p *G1Jac, c int, chChunks []chan g1JacExtended) *G1J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG1Affine(chunk uint64, - chRes chan<- g1JacExtended, - buckets []g1JacExtended, - c uint64, - points []G1Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g1JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - -} - -func (p *G1Jac) msmC4(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC5(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - // c doesn't divide 384, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G1Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g1JacExtended - msmProcessChunkG1Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC8(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - -func (p *G1Jac) msmC16(points []G1Affine, scalars []fr.Element, splitFirstChunk bool) *G1Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g1JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g1JacExtended, 1) - } - - processChunk := func(j int, points []G1Affine, scalars []fr.Element, chChunk chan g1JacExtended) { - var buckets [1 << (c - 1)]g1JacExtended - msmProcessChunkG1Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g1JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - return msmReduceChunkG1Affine(p, c, chChunks[:]) -} - // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf // // This call return an error if len(scalars) != len(points) or if provided config is invalid. @@ -637,7 +292,7 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // we split recursively until nbChunks(c) >= nbTasks, bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) - implementedCs := []uint64{4, 5, 8, 16} + implementedCs := []uint64{4, 5, 8, 10, 16} var C uint64 // approximate cost (in group operations) // cost = bits/c * (nbPoints + 2^{c}) @@ -645,85 +300,128 @@ func (p *G2Jac) MultiExp(points []G2Affine, scalars []fr.Element, config ecc.Mul // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits + 1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs*64)%C != 0 { - nbChunks++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints / 2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit * 2 + if (nbTasksPostSplit <= config.NbTasks/2) || (nbTasksPostSplit-config.NbTasks/2) <= (config.NbTasks-nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p G2Jac + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsmG2(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsmG2(p *G2Jac, c uint64, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInnerG2Jac , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]G2Jac, nbSplits-1) - chDone := make(chan int, nbSplits-1) - for i := 0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInnerG2Jac(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) } - msmInnerG2Jac(p, int(C), points[(nbSplits-1)*nbPoints:], scalars[(nbSplits-1)*nbPoints:], splitFirstChunk) - for i := 0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessorG2(c, chunkStats[j]) + if j == int(nbChunks-1) { + processChunk = getChunkProcessorG2(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan g2JacExtended, 2) + split := n / 2 + go processChunk(uint64(j), chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j), chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil -} -func msmInnerG2Jac(p *G2Jac, c int, points []G2Affine, scalars []fr.Element, splitFirstChunk bool) { + return msmReduceChunkG2Affine(p, int(c), chChunks[:]) +} +// getChunkProcessorG2 decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessorG2(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- g2JacExtended, c uint64, points []G2Affine, digits []uint16) { switch c { + case 2: + return processChunkG2Jacobian[bucketg2JacExtendedC2] + case 3: + return processChunkG2Jacobian[bucketg2JacExtendedC3] case 4: - p.msmC4(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC4] case 5: - p.msmC5(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC5] case 8: - p.msmC8(points, scalars, splitFirstChunk) - + return processChunkG2Jacobian[bucketg2JacExtendedC8] + case 10: + const batchSize = 80 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC10] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC10, bucketG2AffineC10, bitSetC10, pG2AffineC10, ppG2AffineC10, qG2AffineC10, cG2AffineC10] case 16: - p.msmC16(points, scalars, splitFirstChunk) - + const batchSize = 640 + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunkG2Jacobian[bucketg2JacExtendedC16] + } + return processChunkG2BatchAffine[bucketg2JacExtendedC16, bucketG2AffineC16, bitSetC16, pG2AffineC16, ppG2AffineC16, qG2AffineC16, cG2AffineC16] default: - panic("not implemented") + // panic("will not happen c != previous values is not generated by templates") + return processChunkG2Jacobian[bucketg2JacExtendedC16] } } @@ -743,253 +441,188 @@ func msmReduceChunkG2Affine(p *G2Jac, c int, chChunks []chan g2JacExtended) *G2J return p.unsafeFromJacExtended(&_p) } -func msmProcessChunkG2Affine(chunk uint64, - chRes chan<- g2JacExtended, - buckets []g2JacExtended, - c uint64, - points []G2Affine, - scalars []fr.Element) { - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c - 1)) - - for i := 0; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64%c) != 0 && s.shift > (64-c) && s.index < (fr.Limbs-1) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits&msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total g2JacExtended - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total +// selector stores the index, mask and shifts needed to select bits from a scalar +// it is used during the multiExp algorithm or the batch scalar multiplication +type selector struct { + index uint64 // index in the multi-word scalar to select bits from + mask uint64 // mask (c-bit wide) + shift uint64 // shift needed to get our bits on low positions + multiWordSelect bool // set to true if we need to select bits from 2 words (case where c doesn't divide 64) + maskHigh uint64 // same than mask, for index+1 + shiftHigh uint64 // same than shift, for index+1 } -func (p *G2Jac) msmC4(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 4 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } - - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits + c - 1) / c +} - return msmReduceChunkG2Affine(p, c, chChunks[:]) +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c) * c) - fr.Bits + return c + 1 - nbAvailableBits } -func (p *G2Jac) msmC5(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 5 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int +} - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks + 1]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } +// partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits +// if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract +// 2^{c} to the current digit, making it negative. +// negative digits can be processed in a later step as adding -G into the bucket instead of G +// (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { + // number of c-bit radixes in a scalar + nbChunks := computeNbChunks(c) - // c doesn't divide 384, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []G2Affine, scalars []fr.Element) { - var buckets [1 << (lastC - 1)]g2JacExtended - msmProcessChunkG2Affine(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) + digits := make([]uint16, len(scalars)*int(nbChunks)) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + mask := uint64((1 << c) - 1) // low c bits are 1 + max := int(1<<(c-1)) - 1 // max value (inclusive) we want for our digits + cDivides64 := (64 % c) == 0 // if c doesn't divide 64, we may need to select over multiple words - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + // compute offset and word selector / shift to select the right bits of our windows + selectors := make([]selector, nbChunks) + for chunk := uint64(0); chunk < nbChunks; chunk++ { + jc := uint64(chunk * c) + d := selector{} + d.index = jc / 64 + d.shift = jc - (d.index * 64) + d.mask = mask << d.shift + d.multiWordSelect = !cDivides64 && d.shift > (64-c) && d.index < (fr.Limbs-1) + if d.multiWordSelect { + nbBitsHigh := d.shift - uint64(64-c) + d.maskHigh = (1 << nbBitsHigh) - 1 + d.shiftHigh = (c - nbBitsHigh) + } + selectors[chunk] = d } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + parallel.Execute(len(scalars), func(start, end int) { + for i := start; i < end; i++ { + if scalars[i].IsZero() { + // everything is 0, no need to process this scalar + continue + } + scalar := scalars[i].Bits() - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var carry int -func (p *G2Jac) msmC8(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 8 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for each chunk in the scalar, compute the current digit, and an eventual carry + for chunk := uint64(0); chunk < nbChunks-1; chunk++ { + s := selectors[chunk] - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + // init with carry if any + digit := carry + carry = 0 - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) - } + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } + // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract + // 2^{c} to the current digit, making it negative. + if digit > max { + digit -= (1 << c) + carry = 1 + } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } + // if digit is zero, no impact on result + if digit == 0 { + continue + } - return msmReduceChunkG2Affine(p, c, chChunks[:]) -} + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 + } + digits[int(chunk)*len(scalars)+i] = bits + } -func (p *G2Jac) msmC16(points []G2Affine, scalars []fr.Element, splitFirstChunk bool) *G2Jac { - const ( - c = 16 // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1]&s.maskHigh) << s.shiftHigh + } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 + } - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance + }, nbTasks) - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks]chan g2JacExtended - for i := 0; i < len(chChunks); i++ { - chChunks[i] = make(chan g2JacExtended, 1) + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 - processChunk := func(j int, points []G2Affine, scalars []fr.Element, chChunk chan g2JacExtended) { - var buckets [1 << (c - 1)]g2JacExtended - msmProcessChunkG2Affine(uint64(j), chChunk, buckets[:], c, points, scalars) - } + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars) : (chunkID+1)*len(scalars)] - for j := int(nbChunks - 1); j > 0; j-- { - go processChunk(j, points, scalars, chChunks[j]) + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit&1 == 0 { + bucketID -= 1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1<<(c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } + }, nbTasks) + + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps += stat.weight } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan g2JacExtended, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } } - return msmReduceChunkG2Affine(p, c, chChunks[:]) + return digits, chunkStats } diff --git a/ecc/bw6-761/multiexp_affine.go b/ecc/bw6-761/multiexp_affine.go new file mode 100644 index 000000000..5f423838c --- /dev/null +++ b/ecc/bw6-761/multiexp_affine.go @@ -0,0 +1,551 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6761 + +import ( + "github.com/consensys/gnark-crypto/ecc/bw6-761/fp" +) + +type batchOpG1Affine struct { + bucketID uint16 + point G1Affine +} + +// processChunkG1BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG1BatchAffine[BJE ibg1JacExtended, B ibG1Affine, BS bitSet, TP pG1Affine, TPP ppG1Affine, TQ qOpsG1Affine, TC cG1Affine]( + chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g1JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G1Affine used with the batch affine additions + // 1 in g1JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG1Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG1Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G1Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG1AffineC10 [512]G1Affine +type bucketG1AffineC16 [32768]G1Affine + +// buckets: array of G1Affine points of size 1 << (c-1) +type ibG1Affine interface { + bucketG1AffineC10 | + bucketG1AffineC16 +} + +// array of coordinates fp.Element +type cG1Affine interface { + cG1AffineC10 | + cG1AffineC16 +} + +// buckets: array of G1Affine points (for the batch addition) +type pG1Affine interface { + pG1AffineC10 | + pG1AffineC16 +} + +// buckets: array of *G1Affine points (for the batch addition) +type ppG1Affine interface { + ppG1AffineC10 | + ppG1AffineC16 +} + +// buckets: array of G1Affine queue operations (for the batch addition) +type qOpsG1Affine interface { + qG1AffineC10 | + qG1AffineC16 +} + +// batch size 80 when c = 10 +type cG1AffineC10 [80]fp.Element +type pG1AffineC10 [80]G1Affine +type ppG1AffineC10 [80]*G1Affine +type qG1AffineC10 [80]batchOpG1Affine + +// batch size 640 when c = 16 +type cG1AffineC16 [640]fp.Element +type pG1AffineC16 [640]G1Affine +type ppG1AffineC16 [640]*G1Affine +type qG1AffineC16 [640]batchOpG1Affine + +type batchOpG2Affine struct { + bucketID uint16 + point G2Affine +} + +// processChunkG2BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunkG2BatchAffine[BJE ibg2JacExtended, B ibG2Affine, BS bitSet, TP pG2Affine, TPP ppG2Affine, TQ qOpsG2Affine, TC cG2Affine]( + chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // g2JacExtended coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in G2Affine used with the batch affine additions + // 1 in g2JacExtended used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize } + + executeAndReset := func() { + batchAddG2Affine[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOpG2Affine) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *G2Affine, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func() { + for i := 0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func() { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit >> 1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID -= 1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue)-1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketG2AffineC10 [512]G2Affine +type bucketG2AffineC16 [32768]G2Affine + +// buckets: array of G2Affine points of size 1 << (c-1) +type ibG2Affine interface { + bucketG2AffineC10 | + bucketG2AffineC16 +} + +// array of coordinates fp.Element +type cG2Affine interface { + cG2AffineC10 | + cG2AffineC16 +} + +// buckets: array of G2Affine points (for the batch addition) +type pG2Affine interface { + pG2AffineC10 | + pG2AffineC16 +} + +// buckets: array of *G2Affine points (for the batch addition) +type ppG2Affine interface { + ppG2AffineC10 | + ppG2AffineC16 +} + +// buckets: array of G2Affine queue operations (for the batch addition) +type qOpsG2Affine interface { + qG2AffineC10 | + qG2AffineC16 +} + +// batch size 80 when c = 10 +type cG2AffineC10 [80]fp.Element +type pG2AffineC10 [80]G2Affine +type ppG2AffineC10 [80]*G2Affine +type qG2AffineC10 [80]batchOpG2Affine + +// batch size 640 when c = 16 +type cG2AffineC16 [640]fp.Element +type pG2AffineC16 [640]G2Affine +type ppG2AffineC16 [640]*G2Affine +type qG2AffineC16 [640]batchOpG2Affine + +type bitSetC2 [2]bool +type bitSetC3 [4]bool +type bitSetC4 [8]bool +type bitSetC5 [16]bool +type bitSetC8 [128]bool +type bitSetC10 [512]bool +type bitSetC16 [32768]bool + +type bitSet interface { + bitSetC2 | + bitSetC3 | + bitSetC4 | + bitSetC5 | + bitSetC8 | + bitSetC10 | + bitSetC16 +} diff --git a/ecc/bw6-761/multiexp_jacobian.go b/ecc/bw6-761/multiexp_jacobian.go new file mode 100644 index 000000000..ca6b1610d --- /dev/null +++ b/ecc/bw6-761/multiexp_jacobian.go @@ -0,0 +1,143 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package bw6761 + +func processChunkG1Jacobian[B ibg1JacExtended](chunk uint64, + chRes chan<- g1JacExtended, + c uint64, + points []G1Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g1JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg1JacExtendedC2 [2]g1JacExtended +type bucketg1JacExtendedC3 [4]g1JacExtended +type bucketg1JacExtendedC4 [8]g1JacExtended +type bucketg1JacExtendedC5 [16]g1JacExtended +type bucketg1JacExtendedC8 [128]g1JacExtended +type bucketg1JacExtendedC10 [512]g1JacExtended +type bucketg1JacExtendedC16 [32768]g1JacExtended + +type ibg1JacExtended interface { + bucketg1JacExtendedC2 | + bucketg1JacExtendedC3 | + bucketg1JacExtendedC4 | + bucketg1JacExtendedC5 | + bucketg1JacExtendedC8 | + bucketg1JacExtendedC10 | + bucketg1JacExtendedC16 +} + +func processChunkG2Jacobian[B ibg2JacExtended](chunk uint64, + chRes chan<- g2JacExtended, + c uint64, + points []G2Affine, + digits []uint16) { + + var buckets B + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit&1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit >> 1)].subMixed(&points[i]) + } + } + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total g2JacExtended + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +type bucketg2JacExtendedC2 [2]g2JacExtended +type bucketg2JacExtendedC3 [4]g2JacExtended +type bucketg2JacExtendedC4 [8]g2JacExtended +type bucketg2JacExtendedC5 [16]g2JacExtended +type bucketg2JacExtendedC8 [128]g2JacExtended +type bucketg2JacExtendedC10 [512]g2JacExtended +type bucketg2JacExtendedC16 [32768]g2JacExtended + +type ibg2JacExtended interface { + bucketg2JacExtendedC2 | + bucketg2JacExtendedC3 | + bucketg2JacExtendedC4 | + bucketg2JacExtendedC5 | + bucketg2JacExtendedC8 | + bucketg2JacExtendedC10 | + bucketg2JacExtendedC16 +} diff --git a/ecc/bw6-761/multiexp_test.go b/ecc/bw6-761/multiexp_test.go index 33a90ba3a..112cee7c8 100644 --- a/ecc/bw6-761/multiexp_test.go +++ b/ecc/bw6-761/multiexp_test.go @@ -20,9 +20,11 @@ import ( "fmt" "math/big" "math/bits" + "math/rand" "runtime" "sync" "testing" + "time" "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bw6-761/fr" @@ -34,9 +36,9 @@ func TestMultiExpG1(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -55,6 +57,13 @@ func TestMultiExpG1(t *testing.T) { g.AddAssign(&g1Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -78,13 +87,10 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -93,7 +99,7 @@ func TestMultiExpG1(t *testing.T) { )) // cRange is generated from template and contains the available parameters for the multiexp window size - cRange := []uint64{4, 5, 8, 16} + cRange := []uint64{2, 3, 4, 5, 8, 10, 16} if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) cRange = []uint64{5, 16} @@ -114,21 +120,72 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G1Jac, len(cRange)+1) + results := make([]G1Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG1Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG1Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G1Affine + + var expected G1Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g1Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G1] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -151,8 +208,7 @@ func TestMultiExpG1(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g1Gen) } @@ -175,6 +231,87 @@ func TestMultiExpG1(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG1(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G1Affine + var g G1Jac + g.Set(&g1Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g1Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + cRange := []uint64{2, 3, 4, 5, 8, 10, 16} + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + + results := make([]G1Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG1(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G1Jac + _innerMsmG1Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G1Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG1Reference always do ext jacobian with c == 16 +func _innerMsmG1Reference(p *G1Jac, points []G1Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G1Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g1JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g1JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG1Jacobian[bucketg1JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG1Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG1(b *testing.B) { const ( @@ -183,11 +320,33 @@ func BenchmarkMultiExpG1(b *testing.B) { ) var ( - samplePoints [nbSamples]G1Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G1Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG1(samplePoints[:]) var testPoint G1Affine @@ -201,6 +360,20 @@ func BenchmarkMultiExpG1(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -279,9 +452,9 @@ func TestMultiExpG2(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) @@ -300,6 +473,13 @@ func TestMultiExpG2(t *testing.T) { g.AddAssign(&g2Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -323,13 +503,10 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -357,21 +534,72 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - results := make([]G2Jac, len(cRange)+1) + results := make([]G2Jac, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInnerG2Jac(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInnerG2Jac(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i := 1; i < len(results); i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1], cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]G2Affine + + var expected G2Jac + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&g2Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[G2] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } @@ -394,8 +622,7 @@ func TestMultiExpG2(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&g2Gen) } @@ -418,6 +645,85 @@ func TestMultiExpG2(t *testing.T) { properties.TestingRun(t, gopter.ConsoleReporter(false)) } +func TestCrossMultiExpG2(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]G2Affine + var g G2Jac + g.Set(&g2Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&g2Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i := 10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + + results := make([]G2Jac, len(cRange)) + for i, c := range cRange { + _innerMsmG2(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r G2Jac + _innerMsmG2Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got G2Affine + expected.FromJacobian(&r) + + for i := 0; i < len(results); i++ { + got.FromJacobian(&results[i]) + if !expected.Equal(&got) { + t.Fatalf("cross msm failed with c=%d", cRange[i]) + } + } + +} + +// _innerMsmG2Reference always do ext jacobian with c == 16 +func _innerMsmG2Reference(p *G2Jac, points []G2Affine, scalars []fr.Element, config ecc.MultiExpConfig) *G2Jac { + // partition the scalars + digits, _ := partitionScalars(scalars, 16, config.NbTasks) + + nbChunks := computeNbChunks(16) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan g2JacExtended, nbChunks) + for i := 0; i < len(chChunks); i++ { + chChunks[i] = make(chan g2JacExtended, 1) + } + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := processChunkG2Jacobian[bucketg2JacExtendedC16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunkG2Affine(p, int(16), chChunks[:]) +} + func BenchmarkMultiExpG2(b *testing.B) { const ( @@ -426,11 +732,33 @@ func BenchmarkMultiExpG2(b *testing.B) { ) var ( - samplePoints [nbSamples]G2Affine - sampleScalars [nbSamples]fr.Element + samplePoints [nbSamples]G2Affine + sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) + copy(sampleScalarsSmallValues[:], sampleScalars[:]) + copy(sampleScalarsRedundant[:], sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i := 0; i < len(sampleScalarsSmallValues); i++ { + if i%5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i := 0; i < len(sampleScalarsRedundant); i += 100 { + for j := i + 1; j < i+100 && j < len(sampleScalarsRedundant); j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + fillBenchBasesG2(samplePoints[:]) var testPoint G2Affine @@ -444,6 +772,20 @@ func BenchmarkMultiExpG2(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using], ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using], ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using], ecc.MultiExpConfig{}) + } + }) } } @@ -520,11 +862,7 @@ func fillBenchBasesG2(samplePoints []G2Affine) { func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } } diff --git a/ecc/bw6-761/twistededwards/eddsa/doc.go b/ecc/bw6-761/twistededwards/eddsa/doc.go index fa989472f..fe9aab1d4 100644 --- a/ecc/bw6-761/twistededwards/eddsa/doc.go +++ b/ecc/bw6-761/twistededwards/eddsa/doc.go @@ -16,7 +16,7 @@ // Package eddsa provides EdDSA signature scheme on bw6-761's twisted edwards curve. // -// See also +// # See also // // https://en.wikipedia.org/wiki/EdDSA package eddsa diff --git a/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go b/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go index 6918f7c44..f8929c461 100644 --- a/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go +++ b/ecc/bw6-761/twistededwards/eddsa/eddsa_test.go @@ -37,8 +37,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/ecc/bw6-761/twistededwards/eddsa/marshal.go b/ecc/bw6-761/twistededwards/eddsa/marshal.go index c68129087..485e9b710 100644 --- a/ecc/bw6-761/twistededwards/eddsa/marshal.go +++ b/ecc/bw6-761/twistededwards/eddsa/marshal.go @@ -94,11 +94,11 @@ func (privKey *PrivateKey) SetBytes(buf []byte) (int, error) { // Bytes returns the binary representation of sig // as a byte array of size 3*sizeFr x||y||s where -// * x, y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x, y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) func (sig *Signature) Bytes() []byte { var res [sizeSignature]byte sigRBin := sig.R.Bytes() @@ -109,11 +109,12 @@ func (sig *Signature) Bytes() []byte { // SetBytes sets sig from a buffer in binary. // buf is read interpreted as x||y||s where -// * x,y are the coordinates of a point on the twisted -// Edwards represented in big endian -// * s=r+h(r,a,m) mod l, the Hasse bound guarantess that -// s is smaller than sizeFr (in particular it is supposed -// s is NOT blinded) +// - x,y are the coordinates of a point on the twisted +// Edwards represented in big endian +// - s=r+h(r,a,m) mod l, the Hasse bound guarantess that +// s is smaller than sizeFr (in particular it is supposed +// s is NOT blinded) +// // It returns the number of bytes read from buf. func (sig *Signature) SetBytes(buf []byte) (int, error) { n := 0 diff --git a/ecc/bw6-761/twistededwards/point.go b/ecc/bw6-761/twistededwards/point.go index aa332d2aa..5a494f828 100644 --- a/ecc/bw6-761/twistededwards/point.go +++ b/ecc/bw6-761/twistededwards/point.go @@ -49,7 +49,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/ecc/ecc.go b/ecc/ecc.go index 34f0c0736..6e87e14ef 100644 --- a/ecc/ecc.go +++ b/ecc/ecc.go @@ -18,12 +18,12 @@ limitations under the License. // // Also // -// * Multi exponentiation -// * FFT -// * Polynomial commitment schemes -// * MiMC -// * twisted edwards "companion curves" -// * EdDSA (on the "companion" twisted edwards curves) +// - Multi exponentiation +// - FFT +// - Polynomial commitment schemes +// - MiMC +// - twisted edwards "companion curves" +// - EdDSA (on the "companion" twisted edwards curves) package ecc import ( @@ -48,11 +48,12 @@ const ( BW6_761 BW6_633 BW6_756 + SECP256K1 ) // Implemented return the list of curves fully implemented in gnark-crypto func Implemented() []ID { - return []ID{BN254, BLS12_377, BLS12_381, BW6_761, BLS24_315, BW6_633, BLS12_378, BW6_756, BLS24_317} + return []ID{BN254, BLS12_377, BLS12_381, BW6_761, BLS24_315, BW6_633, BLS12_378, BW6_756, BLS24_317, SECP256K1} } func (id ID) String() string { @@ -94,6 +95,8 @@ func (id ID) config() *config.Curve { return &config.BLS24_317 case BW6_756: return &config.BW6_756 + case SECP256K1: + return &config.SECP256K1 default: panic("unimplemented ecc ID") } @@ -109,6 +112,5 @@ func modulus(c *config.Curve, scalarField bool) *big.Int { // MultiExpConfig enables to set optional configuration attribute to a call to MultiExp type MultiExpConfig struct { - NbTasks int // go routines to be used in the multiexp. can be larger than num cpus. - ScalarsMont bool // indicates if the scalars are in montgommery form. Default to false. + NbTasks int // go routines to be used in the multiexp. can be larger than num cpus. } diff --git a/ecc/secp256k1/fp/arith.go b/ecc/secp256k1/fp/arith.go new file mode 100644 index 000000000..66fa66748 --- /dev/null +++ b/ecc/secp256k1/fp/arith.go @@ -0,0 +1,60 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "math/bits" +) + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/ecc/secp256k1/fp/doc.go b/ecc/secp256k1/fp/doc.go new file mode 100644 index 000000000..6e11299df --- /dev/null +++ b/ecc/secp256k1/fp/doc.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package fp contains field arithmetic operations for modulus = 0xffffff...fffc2f. +// +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) +// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// +// type Element [4]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 115792089237316195423570985008687907853269984665640564039457584007908834671663 +// q[base16] = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +package fp diff --git a/ecc/secp256k1/fp/element.go b/ecc/secp256k1/fp/element.go new file mode 100644 index 000000000..2a0e5f31a --- /dev/null +++ b/ecc/secp256k1/fp/element.go @@ -0,0 +1,1225 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/field" + "io" + "math/big" + "math/bits" + "reflect" + "strconv" + "strings" +) + +// Element represents a field element stored on 4 words (uint64) +// +// Element are assumed to be in Montgomery form in all methods. +// +// Modulus q = +// +// q[base10] = 115792089237316195423570985008687907853269984665640564039457584007908834671663 +// q[base16] = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +type Element [4]uint64 + +const ( + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 256 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element +) + +// Field modulus q +const ( + q0 uint64 = 18446744069414583343 + q1 uint64 = 18446744073709551615 + q2 uint64 = 18446744073709551615 + q3 uint64 = 18446744073709551615 +) + +var qElement = Element{ + q0, + q1, + q2, + q3, +} + +var _modulus big.Int // q stored as big.Int + +// Modulus returns q as a big.Int +// +// q[base10] = 115792089237316195423570985008687907853269984665640564039457584007908834671663 +// q[base16] = 0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f +func Modulus() *big.Int { + return new(big.Int).Set(&_modulus) +} + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 15580212934572586289 + +func init() { + _modulus.SetString("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16) +} + +// NewElement returns a new Element from a uint64 value +// +// it is equivalent to +// +// var v Element +// v.SetUint64(...) +func NewElement(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + z[1] = x[1] + z[2] = x[2] + z[3] = x[3] + return z +} + +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + if c1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil + case uint64: + return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil + case int: + return z.SetInt64(int64(c1)), nil + case string: + return z.SetString(c1) + case *big.Int: + if c1 == nil { + return nil, errors.New("can't set fp.Element with ") + } + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil + case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set fp.Element from type " + reflect.TypeOf(i1).String()) + } +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 4294968273 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// Div z = x*y⁻¹ (mod q) +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Bit returns the i'th bit, with lsb == bit 0. +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) Bit(i uint64) uint64 { + j := i / 64 + if j >= 4 { + return 0 + } + return uint64(z[j] >> (i % 64) & 1) +} + +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[3] ^ x[3]) | (z[2] ^ x[2]) | (z[1] ^ x[1]) | (z[0] ^ x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[3] | z[2] | z[1] | z[0]) == 0 +} + +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return (z[3] ^ 0 | z[2] ^ 0 | z[1] ^ 0 | z[0] ^ 4294968273) == 0 +} + +// IsUint64 reports whether z can be represented as an uint64. +func (z *Element) IsUint64() bool { + zz := *z + zz.fromMont() + return zz.FitsOnOneWord() +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) FitsOnOneWord() bool { + return (z[3] | z[2] | z[1]) == 0 +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[3] > _x[3] { + return 1 + } else if _z[3] < _x[3] { + return -1 + } + if _z[2] > _x[2] { + return 1 + } else if _z[2] < _x[2] { + return -1 + } + if _z[1] > _x[1] { + return 1 + } else if _z[1] < _x[1] { + return -1 + } + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint64 + _, b = bits.Sub64(_z[0], 18446744071562067480, 0) + _, b = bits.Sub64(_z[1], 18446744073709551615, b) + _, b = bits.Sub64(_z[2], 18446744073709551615, b) + _, b = bits.Sub64(_z[3], 9223372036854775807, b) + + return b == 0 +} + +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. +func (z *Element) SetRandom() (*Element, error) { + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 4 uint64 + const l = 32 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 256 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. + b := uint(bitLen % 8) + if b == 0 { + b = 8 + } + + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most signicant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 + z[0] = z[0]>>1 | z[1]<<63 + z[1] = z[1]>>1 | z[2]<<63 + z[2] = z[2]>>1 | z[3]<<63 + z[3] >>= 1 + + if carry != 0 { + // when we added q, the result was larger than our available limbs + // when we shift right, we need to set the highest bit + z[3] |= (1 << 63) + } + +} + +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + // if we overflowed the last addition, z >= q + // if z >= q, z = z - q + if carry != 0 { + var b uint64 + // we overflowed, so z >= q + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + return z + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + // if we overflowed the last addition, z >= q + // if z >= q, z = z - q + if carry != 0 { + var b uint64 + // we overflowed, so z >= q + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + return z + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + var borrow uint64 + z[0], borrow = bits.Sub64(q0, x[0], 0) + z[1], borrow = bits.Sub64(q1, x[1], borrow) + z[2], borrow = bits.Sub64(q2, x[2], borrow) + z[3], _ = bits.Sub64(q3, x[3], borrow) + return z +} + +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + z[1] = x0[1] ^ cC&(x0[1]^x1[1]) + z[2] = x0[2] ^ cC&(x0[2]^x1[2]) + z[3] = x0[3] ^ cC&(x0[3]^x1[3]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + if z[3] != 0 { + return 192 + bits.Len64(z[3]) + } + if z[2] != 0 { + return 128 + bits.Len64(z[2]) + } + if z[1] != 0 { + return 64 + bits.Len64(z[1]) + } + return bits.Len64(z[0]) +} + +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 8392367050913, + 1, + 0, + 0, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + +// String returns the decimal representation of z as generated by +// z.Text(10). +func (z *Element) String() string { + return z.Text(10) +} + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := *z + zz.fromMont() + if zz.FitsOnOneWord() { + return strconv.FormatUint(zz[0], base) + } + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) + return r +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := field.BigIntPool.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + field.BigIntPool.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fp.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. +func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fp.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 3 (mod 4) + // using z ≡ ± x^((p+1)/4) (mod q) + var y, square Element + y.expBySqrtExp(*x) + // as we didn't compute the legendre symbol, ensure we found y such that y * y = x + square.Square(&y) + if square.Equal(x) { + return z.Set(&y) + } + return nil +} + +// Inverse z = x⁻¹ (mod q) +// +// note: allocates a big.Int (math/big) +func (z *Element) Inverse(x *Element) *Element { + var _xNonMont big.Int + x.BigInt(&_xNonMont) + _xNonMont.ModInverse(&_xNonMont, Modulus()) + z.SetBigInt(&_xNonMont) + return z +} diff --git a/ecc/secp256k1/fp/element_exp.go b/ecc/secp256k1/fp/element_exp.go new file mode 100644 index 000000000..7b2f8cb83 --- /dev/null +++ b/ecc/secp256k1/fp/element_exp.go @@ -0,0 +1,324 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +// expBySqrtExp is equivalent to z.Exp(x, 3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _1100 = _11 << 2 + // _1111 = _11 + _1100 + // _11110 = 2*_1111 + // _11111 = 1 + _11110 + // _1111100 = _11111 << 2 + // _1111111 = _11 + _1111100 + // x11 = _1111111 << 4 + _1111 + // x22 = x11 << 11 + x11 + // x27 = x22 << 5 + _11111 + // x54 = x27 << 27 + x27 + // x108 = x54 << 54 + x54 + // x216 = x108 << 108 + x108 + // x223 = x216 << 7 + _1111111 + // return ((x223 << 23 + x22) << 6 + _11) << 2 + // + // Operations: 253 squares 13 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + ) + + // var t0,t1,t2,t3 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: z = x^0x3 + z.Mul(&x, z) + + // Step 4: t0 = x^0xc + t0.Square(z) + for s := 1; s < 2; s++ { + t0.Square(t0) + } + + // Step 5: t0 = x^0xf + t0.Mul(z, t0) + + // Step 6: t1 = x^0x1e + t1.Square(t0) + + // Step 7: t2 = x^0x1f + t2.Mul(&x, t1) + + // Step 9: t1 = x^0x7c + t1.Square(t2) + for s := 1; s < 2; s++ { + t1.Square(t1) + } + + // Step 10: t1 = x^0x7f + t1.Mul(z, t1) + + // Step 14: t3 = x^0x7f0 + t3.Square(t1) + for s := 1; s < 4; s++ { + t3.Square(t3) + } + + // Step 15: t0 = x^0x7ff + t0.Mul(t0, t3) + + // Step 26: t3 = x^0x3ff800 + t3.Square(t0) + for s := 1; s < 11; s++ { + t3.Square(t3) + } + + // Step 27: t0 = x^0x3fffff + t0.Mul(t0, t3) + + // Step 32: t3 = x^0x7ffffe0 + t3.Square(t0) + for s := 1; s < 5; s++ { + t3.Square(t3) + } + + // Step 33: t2 = x^0x7ffffff + t2.Mul(t2, t3) + + // Step 60: t3 = x^0x3ffffff8000000 + t3.Square(t2) + for s := 1; s < 27; s++ { + t3.Square(t3) + } + + // Step 61: t2 = x^0x3fffffffffffff + t2.Mul(t2, t3) + + // Step 115: t3 = x^0xfffffffffffffc0000000000000 + t3.Square(t2) + for s := 1; s < 54; s++ { + t3.Square(t3) + } + + // Step 116: t2 = x^0xfffffffffffffffffffffffffff + t2.Mul(t2, t3) + + // Step 224: t3 = x^0xfffffffffffffffffffffffffff000000000000000000000000000 + t3.Square(t2) + for s := 1; s < 108; s++ { + t3.Square(t3) + } + + // Step 225: t2 = x^0xffffffffffffffffffffffffffffffffffffffffffffffffffffff + t2.Mul(t2, t3) + + // Step 232: t2 = x^0x7fffffffffffffffffffffffffffffffffffffffffffffffffffff80 + for s := 0; s < 7; s++ { + t2.Square(t2) + } + + // Step 233: t1 = x^0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff + t1.Mul(t1, t2) + + // Step 256: t1 = x^0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffff800000 + for s := 0; s < 23; s++ { + t1.Square(t1) + } + + // Step 257: t0 = x^0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff + t0.Mul(t0, t1) + + // Step 263: t0 = x^0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc0 + for s := 0; s < 6; s++ { + t0.Square(t0) + } + + // Step 264: z = x^0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc3 + z.Mul(z, t0) + + // Step 266: z = x^0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c + for s := 0; s < 2; s++ { + z.Square(z) + } + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _100 = 2*_10 + // _110 = _10 + _100 + // _111 = 1 + _110 + // _1110 = 2*_111 + // _10101 = _111 + _1110 + // _10111 = _10 + _10101 + // _101110 = 2*_10111 + // _10111000 = _101110 << 2 + // _11100110 = _101110 + _10111000 + // _11111101 = _10111 + _11100110 + // x11 = _11111101 << 3 + _10111 + // x22 = x11 << 11 + x11 + // i29 = 2*x22 + // i31 = i29 << 2 + // i54 = i31 << 22 + i31 + // i122 = (i54 << 20 + i29) << 46 + i54 + // x223 = i122 << 110 + i122 + _111 + // return (x223 << 23 + x22) << 9 + _10111 + // + // Operations: 253 squares 15 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + ) + + // var t0,t1,t2,t3,t4 Element + // Step 1: z = x^0x2 + z.Square(&x) + + // Step 2: t0 = x^0x4 + t0.Square(z) + + // Step 3: t0 = x^0x6 + t0.Mul(z, t0) + + // Step 4: t1 = x^0x7 + t1.Mul(&x, t0) + + // Step 5: t0 = x^0xe + t0.Square(t1) + + // Step 6: t0 = x^0x15 + t0.Mul(t1, t0) + + // Step 7: z = x^0x17 + z.Mul(z, t0) + + // Step 8: t0 = x^0x2e + t0.Square(z) + + // Step 10: t2 = x^0xb8 + t2.Square(t0) + for s := 1; s < 2; s++ { + t2.Square(t2) + } + + // Step 11: t0 = x^0xe6 + t0.Mul(t0, t2) + + // Step 12: t0 = x^0xfd + t0.Mul(z, t0) + + // Step 15: t0 = x^0x7e8 + for s := 0; s < 3; s++ { + t0.Square(t0) + } + + // Step 16: t0 = x^0x7ff + t0.Mul(z, t0) + + // Step 27: t2 = x^0x3ff800 + t2.Square(t0) + for s := 1; s < 11; s++ { + t2.Square(t2) + } + + // Step 28: t0 = x^0x3fffff + t0.Mul(t0, t2) + + // Step 29: t3 = x^0x7ffffe + t3.Square(t0) + + // Step 31: t2 = x^0x1fffff8 + t2.Square(t3) + for s := 1; s < 2; s++ { + t2.Square(t2) + } + + // Step 53: t4 = x^0x7ffffe000000 + t4.Square(t2) + for s := 1; s < 22; s++ { + t4.Square(t4) + } + + // Step 54: t2 = x^0x7ffffffffff8 + t2.Mul(t2, t4) + + // Step 74: t4 = x^0x7ffffffffff800000 + t4.Square(t2) + for s := 1; s < 20; s++ { + t4.Square(t4) + } + + // Step 75: t3 = x^0x7fffffffffffffffe + t3.Mul(t3, t4) + + // Step 121: t3 = x^0x1ffffffffffffffff800000000000 + for s := 0; s < 46; s++ { + t3.Square(t3) + } + + // Step 122: t2 = x^0x1fffffffffffffffffffffffffff8 + t2.Mul(t2, t3) + + // Step 232: t3 = x^0x7ffffffffffffffffffffffffffe0000000000000000000000000000 + t3.Square(t2) + for s := 1; s < 110; s++ { + t3.Square(t3) + } + + // Step 233: t2 = x^0x7ffffffffffffffffffffffffffffffffffffffffffffffffffffff8 + t2.Mul(t2, t3) + + // Step 234: t1 = x^0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff + t1.Mul(t1, t2) + + // Step 257: t1 = x^0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffff800000 + for s := 0; s < 23; s++ { + t1.Square(t1) + } + + // Step 258: t0 = x^0x3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff + t0.Mul(t0, t1) + + // Step 267: t0 = x^0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe00 + for s := 0; s < 9; s++ { + t0.Square(t0) + } + + // Step 268: z = x^0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17 + z.Mul(z, t0) + + return z +} diff --git a/ecc/secp256k1/fp/element_ops_purego.go b/ecc/secp256k1/fp/element_ops_purego.go new file mode 100644 index 000000000..31298ba1f --- /dev/null +++ b/ecc/secp256k1/fp/element_ops_purego.go @@ -0,0 +1,330 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 55834587549, + 0, + 0, + 0, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return z + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(x[0], x[0]) + C, t[1] = madd1(x[0], x[1], C) + C, t[2] = madd1(x[0], x[2], C) + C, t[3] = madd1(x[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[1], x[0], t[0]) + C, t[1] = madd2(x[1], x[1], t[1], C) + C, t[2] = madd2(x[1], x[2], t[2], C) + C, t[3] = madd2(x[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[2], x[0], t[0]) + C, t[1] = madd2(x[2], x[1], t[1], C) + C, t[2] = madd2(x[2], x[2], t[2], C) + C, t[3] = madd2(x[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[3], x[0], t[0]) + C, t[1] = madd2(x[3], x[1], t[1], C) + C, t[2] = madd2(x[3], x[2], t[2], C) + C, t[3] = madd2(x[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return z + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/secp256k1/fp/element_test.go b/ecc/secp256k1/fp/element_test.go new file mode 100644 index 000000000..5d6291c24 --- /dev/null +++ b/ecc/secp256k1/fp/element_test.go @@ -0,0 +1,2288 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fp + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "math/bits" + + "testing" + + "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 8392367050913, + 1, + 0, + 0, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 8392367050913, + 1, + 0, + 0, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} +func TestElementIsRandom(t *testing.T) { + for i := 0; i < 50; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +func TestElementIsUint64(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(v uint64) bool { + var e Element + e.SetUint64(v) + + if !e.IsUint64() { + return false + } + + return e.Uint64() == v + }, + ggen.UInt64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{0, 0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{0, 1}) + staticTestValues = append(staticTestValues, Element{2}) + staticTestValues = append(staticTestValues, Element{0, 2}) + + { + a := qElement + a[3]-- + staticTestValues = append(staticTestValues, a) + } + { + a := qElement + a[3]-- + a[0]++ + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[3] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, s := range testValues { + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementBytes(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLegendre(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17", 16) + const sqrtExponentElement = "3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + const expected = "{\"A\":-1,\"B\":[0,0,42],\"C\":null,\"D\":8000}" + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + + genRandomFq := func() Element { + var g Element + + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } + + return g + } + a := genRandomFq() + + var carry uint64 + a[0], carry = bits.Add64(a[0], qElement[0], carry) + a[1], carry = bits.Add64(a[1], qElement[1], carry) + a[2], carry = bits.Add64(a[2], qElement[2], carry) + a[3], _ = bits.Add64(a[3], qElement[3], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/fr/arith.go b/ecc/secp256k1/fr/arith.go new file mode 100644 index 000000000..83c9fd9ef --- /dev/null +++ b/ecc/secp256k1/fr/arith.go @@ -0,0 +1,60 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "math/bits" +) + +// madd0 hi = a*b + c (discards lo bits) +func madd0(a, b, c uint64) (hi uint64) { + var carry, lo uint64 + hi, lo = bits.Mul64(a, b) + _, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/ecc/secp256k1/fr/doc.go b/ecc/secp256k1/fr/doc.go new file mode 100644 index 000000000..3c90f775e --- /dev/null +++ b/ecc/secp256k1/fr/doc.go @@ -0,0 +1,53 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package fr contains field arithmetic operations for modulus = 0xffffff...364141. +// +// The API is similar to math/big (big.Int), but the operations are significantly faster (up to 20x for the modular multiplication on amd64, see also https://hackmd.io/@gnark/modular_multiplication) +// +// The modulus is hardcoded in all the operations. +// +// Field elements are represented as an array, and assumed to be in Montgomery form in all methods: +// +// type Element [4]uint64 +// +// # Usage +// +// Example API signature: +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element +// +// and can be used like so: +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) +// +// Modulus q = +// +// q[base10] = 115792089237316195423570985008687907852837564279074904382605163141518161494337 +// q[base16] = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +package fr diff --git a/ecc/secp256k1/fr/element.go b/ecc/secp256k1/fr/element.go new file mode 100644 index 000000000..8023eb297 --- /dev/null +++ b/ecc/secp256k1/fr/element.go @@ -0,0 +1,1277 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "github.com/consensys/gnark-crypto/field" + "io" + "math/big" + "math/bits" + "reflect" + "strconv" + "strings" +) + +// Element represents a field element stored on 4 words (uint64) +// +// Element are assumed to be in Montgomery form in all methods. +// +// Modulus q = +// +// q[base10] = 115792089237316195423570985008687907852837564279074904382605163141518161494337 +// q[base16] = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 +// +// # Warning +// +// This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +type Element [4]uint64 + +const ( + Limbs = 4 // number of 64 bits words needed to represent a Element + Bits = 256 // number of bits needed to represent a Element + Bytes = 32 // number of bytes needed to represent a Element +) + +// Field modulus q +const ( + q0 uint64 = 13822214165235122497 + q1 uint64 = 13451932020343611451 + q2 uint64 = 18446744073709551614 + q3 uint64 = 18446744073709551615 +) + +var qElement = Element{ + q0, + q1, + q2, + q3, +} + +var _modulus big.Int // q stored as big.Int + +// Modulus returns q as a big.Int +// +// q[base10] = 115792089237316195423570985008687907852837564279074904382605163141518161494337 +// q[base16] = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141 +func Modulus() *big.Int { + return new(big.Int).Set(&_modulus) +} + +// q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +// used for Montgomery reduction +const qInvNeg uint64 = 5408259542528602431 + +func init() { + _modulus.SetString("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16) +} + +// NewElement returns a new Element from a uint64 value +// +// it is equivalent to +// +// var v Element +// v.SetUint64(...) +func NewElement(v uint64) Element { + z := Element{v} + z.Mul(&z, &rSquare) + return z +} + +// SetUint64 sets z to v and returns z +func (z *Element) SetUint64(v uint64) *Element { + // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form + *z = Element{v} + return z.Mul(z, &rSquare) // z.toMont() +} + +// SetInt64 sets z to v and returns z +func (z *Element) SetInt64(v int64) *Element { + + // absolute value of v + m := v >> 63 + z.SetUint64(uint64((v ^ m) - m)) + + if m != 0 { + // v is negative + z.Neg(z) + } + + return z +} + +// Set z = x and returns z +func (z *Element) Set(x *Element) *Element { + z[0] = x[0] + z[1] = x[1] + z[2] = x[2] + z[3] = x[3] + return z +} + +// SetInterface converts provided interface into Element +// returns an error if provided type is not supported +// supported types: +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte +func (z *Element) SetInterface(i1 interface{}) (*Element, error) { + if i1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + + switch c1 := i1.(type) { + case Element: + return z.Set(&c1), nil + case *Element: + if c1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + return z.Set(c1), nil + case uint8: + return z.SetUint64(uint64(c1)), nil + case uint16: + return z.SetUint64(uint64(c1)), nil + case uint32: + return z.SetUint64(uint64(c1)), nil + case uint: + return z.SetUint64(uint64(c1)), nil + case uint64: + return z.SetUint64(c1), nil + case int8: + return z.SetInt64(int64(c1)), nil + case int16: + return z.SetInt64(int64(c1)), nil + case int32: + return z.SetInt64(int64(c1)), nil + case int64: + return z.SetInt64(c1), nil + case int: + return z.SetInt64(int64(c1)), nil + case string: + return z.SetString(c1) + case *big.Int: + if c1 == nil { + return nil, errors.New("can't set fr.Element with ") + } + return z.SetBigInt(c1), nil + case big.Int: + return z.SetBigInt(&c1), nil + case []byte: + return z.SetBytes(c1), nil + default: + return nil, errors.New("can't set fr.Element from type " + reflect.TypeOf(i1).String()) + } +} + +// SetZero z = 0 +func (z *Element) SetZero() *Element { + z[0] = 0 + z[1] = 0 + z[2] = 0 + z[3] = 0 + return z +} + +// SetOne z = 1 (in Montgomery form) +func (z *Element) SetOne() *Element { + z[0] = 4624529908474429119 + z[1] = 4994812053365940164 + z[2] = 1 + z[3] = 0 + return z +} + +// Div z = x*y⁻¹ (mod q) +func (z *Element) Div(x, y *Element) *Element { + var yInv Element + yInv.Inverse(y) + z.Mul(x, &yInv) + return z +} + +// Bit returns the i'th bit, with lsb == bit 0. +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) Bit(i uint64) uint64 { + j := i / 64 + if j >= 4 { + return 0 + } + return uint64(z[j] >> (i % 64) & 1) +} + +// Equal returns z == x; constant-time +func (z *Element) Equal(x *Element) bool { + return z.NotEqual(x) == 0 +} + +// NotEqual returns 0 if and only if z == x; constant-time +func (z *Element) NotEqual(x *Element) uint64 { + return (z[3] ^ x[3]) | (z[2] ^ x[2]) | (z[1] ^ x[1]) | (z[0] ^ x[0]) +} + +// IsZero returns z == 0 +func (z *Element) IsZero() bool { + return (z[3] | z[2] | z[1] | z[0]) == 0 +} + +// IsOne returns z == 1 +func (z *Element) IsOne() bool { + return (z[3] ^ 0 | z[2] ^ 1 | z[1] ^ 4994812053365940164 | z[0] ^ 4624529908474429119) == 0 +} + +// IsUint64 reports whether z can be represented as an uint64. +func (z *Element) IsUint64() bool { + zz := *z + zz.fromMont() + return zz.FitsOnOneWord() +} + +// Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. +func (z *Element) Uint64() uint64 { + return z.Bits()[0] +} + +// FitsOnOneWord reports whether z words (except the least significant word) are 0 +// +// It is the responsibility of the caller to convert from Montgomery to Regular form if needed. +func (z *Element) FitsOnOneWord() bool { + return (z[3] | z[2] | z[1]) == 0 +} + +// Cmp compares (lexicographic order) z and x and returns: +// +// -1 if z < x +// 0 if z == x +// +1 if z > x +func (z *Element) Cmp(x *Element) int { + _z := z.Bits() + _x := x.Bits() + if _z[3] > _x[3] { + return 1 + } else if _z[3] < _x[3] { + return -1 + } + if _z[2] > _x[2] { + return 1 + } else if _z[2] < _x[2] { + return -1 + } + if _z[1] > _x[1] { + return 1 + } else if _z[1] < _x[1] { + return -1 + } + if _z[0] > _x[0] { + return 1 + } else if _z[0] < _x[0] { + return -1 + } + return 0 +} + +// LexicographicallyLargest returns true if this element is strictly lexicographically +// larger than its negation, false otherwise +func (z *Element) LexicographicallyLargest() bool { + // adapted from github.com/zkcrypto/bls12_381 + // we check if the element is larger than (q-1) / 2 + // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 + + _z := z.Bits() + + var b uint64 + _, b = bits.Sub64(_z[0], 16134479119472337057, 0) + _, b = bits.Sub64(_z[1], 6725966010171805725, b) + _, b = bits.Sub64(_z[2], 18446744073709551615, b) + _, b = bits.Sub64(_z[3], 9223372036854775807, b) + + return b == 0 +} + +// SetRandom sets z to a uniform random value in [0, q). +// +// This might error only if reading from crypto/rand.Reader errors, +// in which case, value of z is undefined. +func (z *Element) SetRandom() (*Element, error) { + // this code is generated for all modulus + // and derived from go/src/crypto/rand/util.go + + // l is number of limbs * 8; the number of bytes needed to reconstruct 4 uint64 + const l = 32 + + // bitLen is the maximum bit length needed to encode a value < q. + const bitLen = 256 + + // k is the maximum byte length needed to encode a value < q. + const k = (bitLen + 7) / 8 + + // b is the number of bits in the most significant byte of q-1. + b := uint(bitLen % 8) + if b == 0 { + b = 8 + } + + var bytes [l]byte + + for { + // note that bytes[k:l] is always 0 + if _, err := io.ReadFull(rand.Reader, bytes[:k]); err != nil { + return nil, err + } + + // Clear unused bits in in the most signicant byte to increase probability + // that the candidate is < q. + bytes[k-1] &= uint8(int(1<> 1 + z[0] = z[0]>>1 | z[1]<<63 + z[1] = z[1]>>1 | z[2]<<63 + z[2] = z[2]>>1 | z[3]<<63 + z[3] >>= 1 + + if carry != 0 { + // when we added q, the result was larger than our available limbs + // when we shift right, we need to set the highest bit + z[3] |= (1 << 63) + } + +} + +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// sets and returns z = z * 1 +func (z *Element) fromMont() *Element { + fromMont(z) + return z +} + +// Add z = x + y (mod q) +func (z *Element) Add(x, y *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], y[0], 0) + z[1], carry = bits.Add64(x[1], y[1], carry) + z[2], carry = bits.Add64(x[2], y[2], carry) + z[3], carry = bits.Add64(x[3], y[3], carry) + // if we overflowed the last addition, z >= q + // if z >= q, z = z - q + if carry != 0 { + var b uint64 + // we overflowed, so z >= q + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + return z + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Double z = x + x (mod q), aka Lsh 1 +func (z *Element) Double(x *Element) *Element { + + var carry uint64 + z[0], carry = bits.Add64(x[0], x[0], 0) + z[1], carry = bits.Add64(x[1], x[1], carry) + z[2], carry = bits.Add64(x[2], x[2], carry) + z[3], carry = bits.Add64(x[3], x[3], carry) + // if we overflowed the last addition, z >= q + // if z >= q, z = z - q + if carry != 0 { + var b uint64 + // we overflowed, so z >= q + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + return z + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Sub z = x - y (mod q) +func (z *Element) Sub(x, y *Element) *Element { + var b uint64 + z[0], b = bits.Sub64(x[0], y[0], 0) + z[1], b = bits.Sub64(x[1], y[1], b) + z[2], b = bits.Sub64(x[2], y[2], b) + z[3], b = bits.Sub64(x[3], y[3], b) + if b != 0 { + var c uint64 + z[0], c = bits.Add64(z[0], q0, 0) + z[1], c = bits.Add64(z[1], q1, c) + z[2], c = bits.Add64(z[2], q2, c) + z[3], _ = bits.Add64(z[3], q3, c) + } + return z +} + +// Neg z = q - x +func (z *Element) Neg(x *Element) *Element { + if x.IsZero() { + z.SetZero() + return z + } + var borrow uint64 + z[0], borrow = bits.Sub64(q0, x[0], 0) + z[1], borrow = bits.Sub64(q1, x[1], borrow) + z[2], borrow = bits.Sub64(q2, x[2], borrow) + z[3], _ = bits.Sub64(q3, x[3], borrow) + return z +} + +// Select is a constant-time conditional move. +// If c=0, z = x0. Else z = x1 +func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { + cC := uint64((int64(c) | -int64(c)) >> 63) // "canonicized" into: 0 if c=0, -1 otherwise + z[0] = x0[0] ^ cC&(x0[0]^x1[0]) + z[1] = x0[1] ^ cC&(x0[1]^x1[1]) + z[2] = x0[2] ^ cC&(x0[2]^x1[2]) + z[3] = x0[3] ^ cC&(x0[3]^x1[3]) + return z +} + +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. +func _mulGeneric(z, x, y *Element) { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _fromMontGeneric(z *Element) { + // the following lines implement z = z * 1 + // with a modified CIOS montgomery multiplication + // see Mul for algorithm documentation + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + { + // m = z[0]n'[0] mod W + m := z[0] * qInvNeg + C := madd0(m, q0, z[0]) + C, z[0] = madd2(m, q1, z[1], C) + C, z[1] = madd2(m, q2, z[2], C) + C, z[2] = madd2(m, q3, z[3], C) + z[3] = C + } + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +func _reduceGeneric(z *Element) { + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } +} + +// BatchInvert returns a new slice with every element inverted. +// Uses Montgomery batch inversion trick +func BatchInvert(a []Element) []Element { + res := make([]Element, len(a)) + if len(a) == 0 { + return res + } + + zeroes := make([]bool, len(a)) + accumulator := One() + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + zeroes[i] = true + continue + } + res[i] = accumulator + accumulator.Mul(&accumulator, &a[i]) + } + + accumulator.Inverse(&accumulator) + + for i := len(a) - 1; i >= 0; i-- { + if zeroes[i] { + continue + } + res[i].Mul(&res[i], &accumulator) + accumulator.Mul(&accumulator, &a[i]) + } + + return res +} + +func _butterflyGeneric(a, b *Element) { + t := *a + a.Add(a, b) + b.Sub(&t, b) +} + +// BitLen returns the minimum number of bits needed to represent z +// returns 0 if z == 0 +func (z *Element) BitLen() int { + if z[3] != 0 { + return 192 + bits.Len64(z[3]) + } + if z[2] != 0 { + return 128 + bits.Len64(z[2]) + } + if z[1] != 0 { + return 64 + bits.Len64(z[1]) + } + return bits.Len64(z[0]) +} + +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + +// Exp z = xᵏ (mod q) +func (z *Element) Exp(x Element, k *big.Int) *Element { + if k.IsUint64() && k.Uint64() == 0 { + return z.SetOne() + } + + e := k + if k.Sign() == -1 { + // negative k, we invert + // if k < 0: xᵏ (mod q) == (x⁻¹)ᵏ (mod q) + x.Inverse(&x) + + // we negate k in a temp big.Int since + // Int.Bit(_) of k and -k is different + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) + e.Neg(k) + } + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z +} + +// rSquare where r is the Montgommery constant +// see section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +var rSquare = Element{ + 9902555850136342848, + 8364476168144746616, + 16616019711348246470, + 11342065889886772165, +} + +// toMont converts z to Montgomery form +// sets and returns z = z * r² +func (z *Element) toMont() *Element { + return z.Mul(z, &rSquare) +} + +// String returns the decimal representation of z as generated by +// z.Text(10). +func (z *Element) String() string { + return z.Text(10) +} + +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[24:32], z[0]) + binary.BigEndian.PutUint64(b[16:24], z[1]) + binary.BigEndian.PutUint64(b[8:16], z[2]) + binary.BigEndian.PutUint64(b[0:8], z[3]) + + return res.SetBytes(b[:]) +} + +// Text returns the string representation of z in the given base. +// Base must be between 2 and 36, inclusive. The result uses the +// lower-case letters 'a' to 'z' for digit values 10 to 35. +// No prefix (such as "0x") is added to the string. If z is a nil +// pointer it returns "". +// If base == 10 and -z fits in a uint16 prefix "-" is added to the string. +func (z *Element) Text(base int) string { + if base < 2 || base > 36 { + panic("invalid base") + } + if z == nil { + return "" + } + + const maxUint16 = 65535 + if base == 10 { + var zzNeg Element + zzNeg.Neg(z) + zzNeg.fromMont() + if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { + return "-" + strconv.FormatUint(zzNeg[0], base) + } + } + zz := *z + zz.fromMont() + if zz.FitsOnOneWord() { + return strconv.FormatUint(zz[0], base) + } + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) + return r +} + +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) +} + +// ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead +func (z Element) ToBigIntRegular(res *big.Int) *big.Int { + z.fromMont() + return z.toBigInt(res) +} + +// Bits provides access to z by returning its value as a little-endian [4]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [4]uint64 { + _z := *z + fromMont(&_z) + return _z +} + +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) + return +} + +// Marshal returns the value of z as a big-endian byte slice +func (z *Element) Marshal() []byte { + b := z.Bytes() + return b[:] +} + +// SetBytes interprets e as the bytes of a big-endian unsigned integer, +// sets z to that value, and returns z. +func (z *Element) SetBytes(e []byte) *Element { + if len(e) == Bytes { + // fast path + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } + } + + // slow path. + // get a big int from our pool + vv := field.BigIntPool.Get() + vv.SetBytes(e) + + // set big int + z.SetBigInt(vv) + + // put temporary object back in pool + field.BigIntPool.Put(vv) + + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian 32-byte integer. +// If e is not a 32-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid fr.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + +// SetBigInt sets z to v and returns z +func (z *Element) SetBigInt(v *big.Int) *Element { + z.SetZero() + + var zero big.Int + + // fast path + c := v.Cmp(&_modulus) + if c == 0 { + // v == 0 + return z + } else if c != 1 && v.Cmp(&zero) != -1 { + // 0 < v < q + return z.setBigInt(v) + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + // copy input + modular reduction + vv.Mod(v, &_modulus) + + // set big int byte value + z.setBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + return z +} + +// setBigInt assumes 0 ⩽ v < q +func (z *Element) setBigInt(v *big.Int) *Element { + vBits := v.Bits() + + if bits.UintSize == 64 { + for i := 0; i < len(vBits); i++ { + z[i] = uint64(vBits[i]) + } + } else { + for i := 0; i < len(vBits); i++ { + if i%2 == 0 { + z[i/2] = uint64(vBits[i]) + } else { + z[i/2] |= uint64(vBits[i]) << 32 + } + } + } + + return z.toMont() +} + +// SetString creates a big.Int with number and calls SetBigInt on z +// +// The number prefix determines the actual base: A prefix of +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 +// and no prefix is accepted. +// +// For base 16, lower and upper case letters are considered the same: +// The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. +// +// An underscore character ”_” may appear between a base +// prefix and an adjacent digit, and between successive digits; such +// underscores do not change the value of the number. +// Incorrect placement of underscores is reported as a panic if there +// are no other errors. +// +// If the number is invalid this method leaves z unchanged and returns nil, error. +func (z *Element) SetString(number string) (*Element, error) { + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + if _, ok := vv.SetString(number, 0); !ok { + return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) + } + + z.SetBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + + return z, nil +} + +// MarshalJSON returns json encoding of z (z.Text(10)) +// If z == nil, returns null +func (z *Element) MarshalJSON() ([]byte, error) { + if z == nil { + return []byte("null"), nil + } + const maxSafeBound = 15 // we encode it as number if it's small + s := z.Text(10) + if len(s) <= maxSafeBound { + return []byte(s), nil + } + var sbb strings.Builder + sbb.WriteByte('"') + sbb.WriteString(s) + sbb.WriteByte('"') + return []byte(sbb.String()), nil +} + +// UnmarshalJSON accepts numbers and strings as input +// See Element.SetString for valid prefixes (0x, 0b, ...) +func (z *Element) UnmarshalJSON(data []byte) error { + s := string(data) + if len(s) > Bits*3 { + return errors.New("value too large (max = Element.Bits * 3)") + } + + // we accept numbers and strings, remove leading and trailing quotes if any + if len(s) > 0 && s[0] == '"' { + s = s[1:] + } + if len(s) > 0 && s[len(s)-1] == '"' { + s = s[:len(s)-1] + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + if _, ok := vv.SetString(s, 0); !ok { + return errors.New("can't parse into a big.Int: " + s) + } + + z.SetBigInt(vv) + + // release object into pool + field.BigIntPool.Put(vv) + return nil +} + +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 32-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[24:32]) + z[1] = binary.BigEndian.Uint64((*b)[16:24]) + z[2] = binary.BigEndian.Uint64((*b)[8:16]) + z[3] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[24:32], e[0]) + binary.BigEndian.PutUint64((*b)[16:24], e[1]) + binary.BigEndian.PutUint64((*b)[8:16], e[2]) + binary.BigEndian.PutUint64((*b)[0:8], e[3]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + z[1] = binary.LittleEndian.Uint64((*b)[8:16]) + z[2] = binary.LittleEndian.Uint64((*b)[16:24]) + z[3] = binary.LittleEndian.Uint64((*b)[24:32]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid fr.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) + binary.LittleEndian.PutUint64((*b)[8:16], e[1]) + binary.LittleEndian.PutUint64((*b)[16:24], e[2]) + binary.LittleEndian.PutUint64((*b)[24:32], e[3]) +} + +func (littleEndian) String() string { return "LittleEndian" } + +// Legendre returns the Legendre symbol of z (either +1, -1, or 0.) +func (z *Element) Legendre() int { + var l Element + // z^((q-1)/2) + l.expByLegendreExp(*z) + + if l.IsZero() { + return 0 + } + + // if l == 1 + if l.IsOne() { + return 1 + } + return -1 +} + +// Sqrt z = √x (mod q) +// if the square root doesn't exist (x is not a square mod q) +// Sqrt leaves z unchanged and returns nil +func (z *Element) Sqrt(x *Element) *Element { + // q ≡ 1 (mod 4) + // see modSqrtTonelliShanks in math/big/int.go + // using https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf + + var y, b, t, w Element + // w = x^((s-1)/2)) + w.expBySqrtExp(*x) + + // y = x^((s+1)/2)) = w * x + y.Mul(x, &w) + + // b = xˢ = w * w * x = y * x + b.Mul(&w, &y) + + // g = nonResidue ^ s + var g = Element{ + 16727483617216526287, + 14607548025256143850, + 15265302390528700431, + 15433920720005950142, + } + r := uint64(6) + + // compute legendre symbol + // t = x^((q-1)/2) = r-1 squaring of xˢ + t = b + for i := uint64(0); i < r-1; i++ { + t.Square(&t) + } + if t.IsZero() { + return z.SetZero() + } + if !t.IsOne() { + // t != 1, we don't have a square root + return nil + } + for { + var m uint64 + t = b + + // for t != 1 + for !t.IsOne() { + t.Square(&t) + m++ + } + + if m == 0 { + return z.Set(&y) + } + // t = g^(2^(r-m-1)) (mod q) + ge := int(r - m - 1) + t = g + for ge > 0 { + t.Square(&t) + ge-- + } + + g.Square(&t) + y.Mul(&y, &t) + b.Mul(&b, &g) + r = m + } +} + +// Inverse z = x⁻¹ (mod q) +// +// note: allocates a big.Int (math/big) +func (z *Element) Inverse(x *Element) *Element { + var _xNonMont big.Int + x.BigInt(&_xNonMont) + _xNonMont.ModInverse(&_xNonMont, Modulus()) + z.SetBigInt(&_xNonMont) + return z +} diff --git a/ecc/secp256k1/fr/element_exp.go b/ecc/secp256k1/fr/element_exp.go new file mode 100644 index 000000000..a93ff5459 --- /dev/null +++ b/ecc/secp256k1/fr/element_exp.go @@ -0,0 +1,695 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +// expBySqrtExp is equivalent to z.Exp(x, 1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expBySqrtExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _111 = _10 + _101 + // _1001 = _10 + _111 + // _1011 = _10 + _1001 + // _1101 = _10 + _1011 + // _1111 = _10 + _1101 + // _1111000 = _1111 << 3 + // _1111111 = _111 + _1111000 + // _11111110 = 2*_1111111 + // _11111111 = 1 + _11111110 + // i21 = _11111111 << 7 + // x15 = _1111111 + i21 + // i30 = i21 << 8 + // x23 = x15 + i30 + // x31 = i30 << 8 + x23 + // x32 = 2*x31 + 1 + // x64 = x32 << 32 + x32 + // x96 = x64 << 32 + x32 + // x127 = x96 << 31 + x31 + // i154 = ((x127 << 5 + _1011) << 3 + _101) << 4 + // i166 = ((_101 + i154) << 4 + _111) << 5 + _1101 + // i181 = ((i166 << 2 + _11) << 5 + _111) << 6 + // i193 = ((_1101 + i181) << 5 + _1011) << 4 + _1101 + // i214 = ((i193 << 3 + 1) << 6 + _101) << 10 + // i230 = ((_111 + i214) << 4 + _111) << 9 + _11111111 + // i247 = ((i230 << 5 + _1001) << 6 + _1011) << 4 + // i261 = ((_1101 + i247) << 5 + _11) << 6 + _1101 + // i283 = ((i261 << 10 + _1101) << 4 + _1001) << 6 + // return 2*(1 + i283) + // + // Operations: 246 squares 39 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + t8 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7,t8 Element + // Step 1: t3 = x^0x2 + t3.Square(&x) + + // Step 2: t1 = x^0x3 + t1.Mul(&x, t3) + + // Step 3: t5 = x^0x5 + t5.Mul(t3, t1) + + // Step 4: t4 = x^0x7 + t4.Mul(t3, t5) + + // Step 5: z = x^0x9 + z.Mul(t3, t4) + + // Step 6: t2 = x^0xb + t2.Mul(t3, z) + + // Step 7: t0 = x^0xd + t0.Mul(t3, t2) + + // Step 8: t3 = x^0xf + t3.Mul(t3, t0) + + // Step 11: t3 = x^0x78 + for s := 0; s < 3; s++ { + t3.Square(t3) + } + + // Step 12: t6 = x^0x7f + t6.Mul(t4, t3) + + // Step 13: t3 = x^0xfe + t3.Square(t6) + + // Step 14: t3 = x^0xff + t3.Mul(&x, t3) + + // Step 21: t7 = x^0x7f80 + t7.Square(t3) + for s := 1; s < 7; s++ { + t7.Square(t7) + } + + // Step 22: t6 = x^0x7fff + t6.Mul(t6, t7) + + // Step 30: t7 = x^0x7f8000 + for s := 0; s < 8; s++ { + t7.Square(t7) + } + + // Step 31: t6 = x^0x7fffff + t6.Mul(t6, t7) + + // Step 39: t7 = x^0x7f800000 + for s := 0; s < 8; s++ { + t7.Square(t7) + } + + // Step 40: t6 = x^0x7fffffff + t6.Mul(t6, t7) + + // Step 41: t7 = x^0xfffffffe + t7.Square(t6) + + // Step 42: t7 = x^0xffffffff + t7.Mul(&x, t7) + + // Step 74: t8 = x^0xffffffff00000000 + t8.Square(t7) + for s := 1; s < 32; s++ { + t8.Square(t8) + } + + // Step 75: t8 = x^0xffffffffffffffff + t8.Mul(t7, t8) + + // Step 107: t8 = x^0xffffffffffffffff00000000 + for s := 0; s < 32; s++ { + t8.Square(t8) + } + + // Step 108: t7 = x^0xffffffffffffffffffffffff + t7.Mul(t7, t8) + + // Step 139: t7 = x^0x7fffffffffffffffffffffff80000000 + for s := 0; s < 31; s++ { + t7.Square(t7) + } + + // Step 140: t6 = x^0x7fffffffffffffffffffffffffffffff + t6.Mul(t6, t7) + + // Step 145: t6 = x^0xfffffffffffffffffffffffffffffffe0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 146: t6 = x^0xfffffffffffffffffffffffffffffffeb + t6.Mul(t2, t6) + + // Step 149: t6 = x^0x7fffffffffffffffffffffffffffffff58 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 150: t6 = x^0x7fffffffffffffffffffffffffffffff5d + t6.Mul(t5, t6) + + // Step 154: t6 = x^0x7fffffffffffffffffffffffffffffff5d0 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 155: t6 = x^0x7fffffffffffffffffffffffffffffff5d5 + t6.Mul(t5, t6) + + // Step 159: t6 = x^0x7fffffffffffffffffffffffffffffff5d50 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 160: t6 = x^0x7fffffffffffffffffffffffffffffff5d57 + t6.Mul(t4, t6) + + // Step 165: t6 = x^0xfffffffffffffffffffffffffffffffebaae0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 166: t6 = x^0xfffffffffffffffffffffffffffffffebaaed + t6.Mul(t0, t6) + + // Step 168: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb4 + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 169: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb7 + t6.Mul(t1, t6) + + // Step 174: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 175: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e7 + t6.Mul(t4, t6) + + // Step 181: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9c0 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 182: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd + t6.Mul(t0, t6) + + // Step 187: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739a0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 188: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739ab + t6.Mul(t2, t6) + + // Step 192: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739ab0 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 193: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd + t6.Mul(t0, t6) + + // Step 196: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e8 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 197: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9 + t6.Mul(&x, t6) + + // Step 203: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a40 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 204: t5 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a45 + t5.Mul(t5, t6) + + // Step 214: t5 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e91400 + for s := 0; s < 10; s++ { + t5.Square(t5) + } + + // Step 215: t5 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e91407 + t5.Mul(t4, t5) + + // Step 219: t5 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e914070 + for s := 0; s < 4; s++ { + t5.Square(t5) + } + + // Step 220: t4 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e914077 + t4.Mul(t4, t5) + + // Step 229: t4 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280ee00 + for s := 0; s < 9; s++ { + t4.Square(t4) + } + + // Step 230: t3 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff + t3.Mul(t3, t4) + + // Step 235: t3 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe0 + for s := 0; s < 5; s++ { + t3.Square(t3) + } + + // Step 236: t3 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe9 + t3.Mul(z, t3) + + // Step 242: t3 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa40 + for s := 0; s < 6; s++ { + t3.Square(t3) + } + + // Step 243: t2 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4b + t2.Mul(t2, t3) + + // Step 247: t2 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4b0 + for s := 0; s < 4; s++ { + t2.Square(t2) + } + + // Step 248: t2 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd + t2.Mul(t0, t2) + + // Step 253: t2 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a0 + for s := 0; s < 5; s++ { + t2.Square(t2) + } + + // Step 254: t1 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3 + t1.Mul(t1, t2) + + // Step 260: t1 = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8c0 + for s := 0; s < 6; s++ { + t1.Square(t1) + } + + // Step 261: t1 = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd + t1.Mul(t0, t1) + + // Step 271: t1 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a33400 + for s := 0; s < 10; s++ { + t1.Square(t1) + } + + // Step 272: t0 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d + t0.Mul(t0, t1) + + // Step 276: t0 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d0 + for s := 0; s < 4; s++ { + t0.Square(t0) + } + + // Step 277: z = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d9 + z.Mul(z, t0) + + // Step 283: z = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd03640 + for s := 0; s < 6; s++ { + z.Square(z) + } + + // Step 284: z = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd03641 + z.Mul(&x, z) + + // Step 285: z = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82 + z.Square(z) + + return z +} + +// expByLegendreExp is equivalent to z.Exp(x, 7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0) +// +// uses github.com/mmcloughlin/addchain v0.4.0 to generate a shorter addition chain +func (z *Element) expByLegendreExp(x Element) *Element { + // addition chain: + // + // _10 = 2*1 + // _11 = 1 + _10 + // _101 = _10 + _11 + // _111 = _10 + _101 + // _1001 = _10 + _111 + // _1011 = _10 + _1001 + // _1101 = _10 + _1011 + // _1111 = _10 + _1101 + // _1111000 = _1111 << 3 + // _1111111 = _111 + _1111000 + // _11111110 = 2*_1111111 + // _11111111 = 1 + _11111110 + // i21 = _11111111 << 7 + // x15 = _1111111 + i21 + // i30 = i21 << 8 + // x23 = x15 + i30 + // x31 = i30 << 8 + x23 + // x32 = 2*x31 + 1 + // x64 = x32 << 32 + x32 + // x96 = x64 << 32 + x32 + // x127 = x96 << 31 + x31 + // i154 = ((x127 << 5 + _1011) << 3 + _101) << 4 + // i166 = ((_101 + i154) << 4 + _111) << 5 + _1101 + // i181 = ((i166 << 2 + _11) << 5 + _111) << 6 + // i193 = ((_1101 + i181) << 5 + _1011) << 4 + _1101 + // i214 = ((i193 << 3 + 1) << 6 + _101) << 10 + // i230 = ((_111 + i214) << 4 + _111) << 9 + _11111111 + // i247 = ((i230 << 5 + _1001) << 6 + _1011) << 4 + // i261 = ((_1101 + i247) << 5 + _11) << 6 + _1101 + // i285 = ((i261 << 10 + _1101) << 4 + _1001) << 8 + // return (_101 + i285) << 5 + // + // Operations: 252 squares 39 multiplies + + // Allocate Temporaries. + var ( + t0 = new(Element) + t1 = new(Element) + t2 = new(Element) + t3 = new(Element) + t4 = new(Element) + t5 = new(Element) + t6 = new(Element) + t7 = new(Element) + t8 = new(Element) + ) + + // var t0,t1,t2,t3,t4,t5,t6,t7,t8 Element + // Step 1: t4 = x^0x2 + t4.Square(&x) + + // Step 2: t2 = x^0x3 + t2.Mul(&x, t4) + + // Step 3: z = x^0x5 + z.Mul(t4, t2) + + // Step 4: t5 = x^0x7 + t5.Mul(t4, z) + + // Step 5: t0 = x^0x9 + t0.Mul(t4, t5) + + // Step 6: t3 = x^0xb + t3.Mul(t4, t0) + + // Step 7: t1 = x^0xd + t1.Mul(t4, t3) + + // Step 8: t4 = x^0xf + t4.Mul(t4, t1) + + // Step 11: t4 = x^0x78 + for s := 0; s < 3; s++ { + t4.Square(t4) + } + + // Step 12: t6 = x^0x7f + t6.Mul(t5, t4) + + // Step 13: t4 = x^0xfe + t4.Square(t6) + + // Step 14: t4 = x^0xff + t4.Mul(&x, t4) + + // Step 21: t7 = x^0x7f80 + t7.Square(t4) + for s := 1; s < 7; s++ { + t7.Square(t7) + } + + // Step 22: t6 = x^0x7fff + t6.Mul(t6, t7) + + // Step 30: t7 = x^0x7f8000 + for s := 0; s < 8; s++ { + t7.Square(t7) + } + + // Step 31: t6 = x^0x7fffff + t6.Mul(t6, t7) + + // Step 39: t7 = x^0x7f800000 + for s := 0; s < 8; s++ { + t7.Square(t7) + } + + // Step 40: t6 = x^0x7fffffff + t6.Mul(t6, t7) + + // Step 41: t7 = x^0xfffffffe + t7.Square(t6) + + // Step 42: t7 = x^0xffffffff + t7.Mul(&x, t7) + + // Step 74: t8 = x^0xffffffff00000000 + t8.Square(t7) + for s := 1; s < 32; s++ { + t8.Square(t8) + } + + // Step 75: t8 = x^0xffffffffffffffff + t8.Mul(t7, t8) + + // Step 107: t8 = x^0xffffffffffffffff00000000 + for s := 0; s < 32; s++ { + t8.Square(t8) + } + + // Step 108: t7 = x^0xffffffffffffffffffffffff + t7.Mul(t7, t8) + + // Step 139: t7 = x^0x7fffffffffffffffffffffff80000000 + for s := 0; s < 31; s++ { + t7.Square(t7) + } + + // Step 140: t6 = x^0x7fffffffffffffffffffffffffffffff + t6.Mul(t6, t7) + + // Step 145: t6 = x^0xfffffffffffffffffffffffffffffffe0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 146: t6 = x^0xfffffffffffffffffffffffffffffffeb + t6.Mul(t3, t6) + + // Step 149: t6 = x^0x7fffffffffffffffffffffffffffffff58 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 150: t6 = x^0x7fffffffffffffffffffffffffffffff5d + t6.Mul(z, t6) + + // Step 154: t6 = x^0x7fffffffffffffffffffffffffffffff5d0 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 155: t6 = x^0x7fffffffffffffffffffffffffffffff5d5 + t6.Mul(z, t6) + + // Step 159: t6 = x^0x7fffffffffffffffffffffffffffffff5d50 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 160: t6 = x^0x7fffffffffffffffffffffffffffffff5d57 + t6.Mul(t5, t6) + + // Step 165: t6 = x^0xfffffffffffffffffffffffffffffffebaae0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 166: t6 = x^0xfffffffffffffffffffffffffffffffebaaed + t6.Mul(t1, t6) + + // Step 168: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb4 + for s := 0; s < 2; s++ { + t6.Square(t6) + } + + // Step 169: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb7 + t6.Mul(t2, t6) + + // Step 174: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 175: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e7 + t6.Mul(t5, t6) + + // Step 181: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9c0 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 182: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd + t6.Mul(t1, t6) + + // Step 187: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739a0 + for s := 0; s < 5; s++ { + t6.Square(t6) + } + + // Step 188: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739ab + t6.Mul(t3, t6) + + // Step 192: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739ab0 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 193: t6 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd + t6.Mul(t1, t6) + + // Step 196: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e8 + for s := 0; s < 3; s++ { + t6.Square(t6) + } + + // Step 197: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9 + t6.Mul(&x, t6) + + // Step 203: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a40 + for s := 0; s < 6; s++ { + t6.Square(t6) + } + + // Step 204: t6 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a45 + t6.Mul(z, t6) + + // Step 214: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e91400 + for s := 0; s < 10; s++ { + t6.Square(t6) + } + + // Step 215: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e91407 + t6.Mul(t5, t6) + + // Step 219: t6 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e914070 + for s := 0; s < 4; s++ { + t6.Square(t6) + } + + // Step 220: t5 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e914077 + t5.Mul(t5, t6) + + // Step 229: t5 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280ee00 + for s := 0; s < 9; s++ { + t5.Square(t5) + } + + // Step 230: t4 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff + t4.Mul(t4, t5) + + // Step 235: t4 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe0 + for s := 0; s < 5; s++ { + t4.Square(t4) + } + + // Step 236: t4 = x^0x7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe9 + t4.Mul(t0, t4) + + // Step 242: t4 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa40 + for s := 0; s < 6; s++ { + t4.Square(t4) + } + + // Step 243: t3 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4b + t3.Mul(t3, t4) + + // Step 247: t3 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4b0 + for s := 0; s < 4; s++ { + t3.Square(t3) + } + + // Step 248: t3 = x^0x1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd + t3.Mul(t1, t3) + + // Step 253: t3 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a0 + for s := 0; s < 5; s++ { + t3.Square(t3) + } + + // Step 254: t2 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3 + t2.Mul(t2, t3) + + // Step 260: t2 = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8c0 + for s := 0; s < 6; s++ { + t2.Square(t2) + } + + // Step 261: t2 = x^0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd + t2.Mul(t1, t2) + + // Step 271: t2 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a33400 + for s := 0; s < 10; s++ { + t2.Square(t2) + } + + // Step 272: t1 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d + t1.Mul(t1, t2) + + // Step 276: t1 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d0 + for s := 0; s < 4; s++ { + t1.Square(t1) + } + + // Step 277: t0 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d9 + t0.Mul(t0, t1) + + // Step 285: t0 = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d900 + for s := 0; s < 8; s++ { + t0.Square(t0) + } + + // Step 286: z = x^0x3fffffffffffffffffffffffffffffffaeabb739abd2280eeff497a3340d905 + z.Mul(z, t0) + + // Step 291: z = x^0x7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0 + for s := 0; s < 5; s++ { + z.Square(z) + } + + return z +} diff --git a/ecc/secp256k1/fr/element_ops_purego.go b/ecc/secp256k1/fr/element_ops_purego.go new file mode 100644 index 000000000..7d3b6e3dd --- /dev/null +++ b/ecc/secp256k1/fr/element_ops_purego.go @@ -0,0 +1,330 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + _x := *x + x.Double(x).Add(x, &_x) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + _x := *x + x.Double(x).Double(x).Add(x, &_x) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y = Element{ + 4778656589038923699, + 9592324472628567287, + 16, + 0, + } + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +func (z *Element) Mul(x, y *Element) *Element { + + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + C, t[1] = madd1(y[0], x[1], C) + C, t[2] = madd1(y[0], x[2], C) + C, t[3] = madd1(y[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[1], x[0], t[0]) + C, t[1] = madd2(y[1], x[1], t[1], C) + C, t[2] = madd2(y[1], x[2], t[2], C) + C, t[3] = madd2(y[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[2], x[0], t[0]) + C, t[1] = madd2(y[2], x[1], t[1], C) + C, t[2] = madd2(y[2], x[2], t[2], C) + C, t[3] = madd2(y[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(y[3], x[0], t[0]) + C, t[1] = madd2(y[3], x[1], t[1], C) + C, t[2] = madd2(y[3], x[2], t[2], C) + C, t[3] = madd2(y[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return z + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} + +// Square z = x * x (mod q) +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + var t [5]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(x[0], x[0]) + C, t[1] = madd1(x[0], x[1], C) + C, t[2] = madd1(x[0], x[2], C) + C, t[3] = madd1(x[0], x[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[1], x[0], t[0]) + C, t[1] = madd2(x[1], x[1], t[1], C) + C, t[2] = madd2(x[1], x[2], t[2], C) + C, t[3] = madd2(x[1], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[2], x[0], t[0]) + C, t[1] = madd2(x[2], x[1], t[1], C) + C, t[2] = madd2(x[2], x[2], t[2], C) + C, t[3] = madd2(x[2], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + // ----------------------------------- + // First loop + + C, t[0] = madd1(x[3], x[0], t[0]) + C, t[1] = madd2(x[3], x[1], t[1], C) + C, t[2] = madd2(x[3], x[2], t[2], C) + C, t[3] = madd2(x[3], x[3], t[3], C) + + t[4], D = bits.Add64(t[4], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + C, t[0] = madd2(m, q1, t[1], C) + C, t[1] = madd2(m, q2, t[2], C) + C, t[2] = madd2(m, q3, t[3], C) + + t[3], C = bits.Add64(t[4], C, 0) + t[4], _ = bits.Add64(0, D, C) + + if t[4] != 0 { + // we need to reduce, we have a result on 5 words + var b uint64 + z[0], b = bits.Sub64(t[0], q0, 0) + z[1], b = bits.Sub64(t[1], q1, b) + z[2], b = bits.Sub64(t[2], q2, b) + z[3], _ = bits.Sub64(t[3], q3, b) + return z + } + + // copy t into z + z[0] = t[0] + z[1] = t[1] + z[2] = t[2] + z[3] = t[3] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + var b uint64 + z[0], b = bits.Sub64(z[0], q0, 0) + z[1], b = bits.Sub64(z[1], q1, b) + z[2], b = bits.Sub64(z[2], q2, b) + z[3], _ = bits.Sub64(z[3], q3, b) + } + return z +} diff --git a/ecc/secp256k1/fr/element_test.go b/ecc/secp256k1/fr/element_test.go new file mode 100644 index 000000000..00b4192a2 --- /dev/null +++ b/ecc/secp256k1/fr/element_test.go @@ -0,0 +1,2288 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package fr + +import ( + "crypto/rand" + "encoding/json" + "fmt" + "math/big" + "math/bits" + + "testing" + + "github.com/leanovate/gopter" + ggen "github.com/leanovate/gopter/gen" + "github.com/leanovate/gopter/prop" + + "github.com/stretchr/testify/require" +) + +// ------------------------------------------------------------------------------------------------- +// benchmarks +// most benchmarks are rudimentary and should sample a large number of random inputs +// or be run multiple times to ensure it didn't measure the fastest path of the function + +var benchResElement Element + +func BenchmarkElementSelect(b *testing.B) { + var x, y Element + x.SetRandom() + y.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Select(i%3, &x, &y) + } +} + +func BenchmarkElementSetRandom(b *testing.B) { + var x Element + x.SetRandom() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = x.SetRandom() + } +} + +func BenchmarkElementSetBytes(b *testing.B) { + var x Element + x.SetRandom() + bb := x.Bytes() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.SetBytes(bb[:]) + } + +} + +func BenchmarkElementMulByConstants(b *testing.B) { + b.Run("mulBy3", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy3(&benchResElement) + } + }) + b.Run("mulBy5", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy5(&benchResElement) + } + }) + b.Run("mulBy13", func(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + MulBy13(&benchResElement) + } + }) +} + +func BenchmarkElementInverse(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchResElement.Inverse(&x) + } + +} + +func BenchmarkElementButterfly(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + Butterfly(&x, &benchResElement) + } +} + +func BenchmarkElementExp(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b1, _ := rand.Int(rand.Reader, Modulus()) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Exp(x, b1) + } +} + +func BenchmarkElementDouble(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Double(&benchResElement) + } +} + +func BenchmarkElementAdd(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Add(&x, &benchResElement) + } +} + +func BenchmarkElementSub(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sub(&x, &benchResElement) + } +} + +func BenchmarkElementNeg(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Neg(&benchResElement) + } +} + +func BenchmarkElementDiv(b *testing.B) { + var x Element + x.SetRandom() + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Div(&x, &benchResElement) + } +} + +func BenchmarkElementFromMont(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.fromMont() + } +} + +func BenchmarkElementSquare(b *testing.B) { + benchResElement.SetRandom() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Square(&benchResElement) + } +} + +func BenchmarkElementSqrt(b *testing.B) { + var a Element + a.SetUint64(4) + a.Neg(&a) + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Sqrt(&a) + } +} + +func BenchmarkElementMul(b *testing.B) { + x := Element{ + 9902555850136342848, + 8364476168144746616, + 16616019711348246470, + 11342065889886772165, + } + benchResElement.SetOne() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Mul(&benchResElement, &x) + } +} + +func BenchmarkElementCmp(b *testing.B) { + x := Element{ + 9902555850136342848, + 8364476168144746616, + 16616019711348246470, + 11342065889886772165, + } + benchResElement = x + benchResElement[0] = 0 + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchResElement.Cmp(&x) + } +} + +func TestElementCmp(t *testing.T) { + var x, y Element + + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + one := One() + y.Sub(&y, &one) + + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } + + x = y + if x.Cmp(&y) != 0 { + t.Fatal("x == y") + } + + x.Sub(&x, &one) + if x.Cmp(&y) != -1 { + t.Fatal("x < y") + } + if y.Cmp(&x) != 1 { + t.Fatal("x < y") + } +} +func TestElementIsRandom(t *testing.T) { + for i := 0; i < 50; i++ { + var x, y Element + x.SetRandom() + y.SetRandom() + if x.Equal(&y) { + t.Fatal("2 random numbers are unlikely to be equal") + } + } +} + +func TestElementIsUint64(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(v uint64) bool { + var e Element + e.SetUint64(v) + + if !e.IsUint64() { + return false + } + + return e.Uint64() == v + }, + ggen.UInt64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNegZero(t *testing.T) { + var a, b Element + b.SetZero() + for a.IsZero() { + a.SetRandom() + } + a.Neg(&b) + if !a.IsZero() { + t.Fatal("neg(0) != 0") + } +} + +// ------------------------------------------------------------------------------------------------- +// Gopter tests +// most of them are generated with a template + +const ( + nbFuzzShort = 200 + nbFuzz = 1000 +) + +// special values to be used in tests +var staticTestValues []Element + +func init() { + staticTestValues = append(staticTestValues, Element{}) // zero + staticTestValues = append(staticTestValues, One()) // one + staticTestValues = append(staticTestValues, rSquare) // r² + var e, one Element + one.SetOne() + e.Sub(&qElement, &one) + staticTestValues = append(staticTestValues, e) // q - 1 + e.Double(&one) + staticTestValues = append(staticTestValues, e) // 2 + + { + a := qElement + a[0]-- + staticTestValues = append(staticTestValues, a) + } + staticTestValues = append(staticTestValues, Element{0}) + staticTestValues = append(staticTestValues, Element{0, 0}) + staticTestValues = append(staticTestValues, Element{1}) + staticTestValues = append(staticTestValues, Element{0, 1}) + staticTestValues = append(staticTestValues, Element{2}) + staticTestValues = append(staticTestValues, Element{0, 2}) + + { + a := qElement + a[3]-- + staticTestValues = append(staticTestValues, a) + } + { + a := qElement + a[3]-- + a[0]++ + staticTestValues = append(staticTestValues, a) + } + + { + a := qElement + a[3] = 0 + staticTestValues = append(staticTestValues, a) + } + +} + +func TestElementReduce(t *testing.T) { + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, s := range testValues { + expected := s + reduce(&s) + _reduceGeneric(&expected) + if !s.Equal(&expected) { + t.Fatal("reduce failed: asm and generic impl don't match") + } + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + + properties.Property("reduce should output a result smaller than modulus", prop.ForAll( + func(a Element) bool { + b := a + reduce(&a) + _reduceGeneric(&b) + return a.smallerThanModulus() && a.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementEqual(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("x.Equal(&y) iff x == y; likely false for random pairs", prop.ForAll( + func(a testPairElement, b testPairElement) bool { + return a.element.Equal(&b.element) == (a.element == b.element) + }, + genA, + genB, + )) + + properties.Property("x.Equal(&y) if x == y", prop.ForAll( + func(a testPairElement) bool { + b := a.element + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementBytes(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("SetBytes(Bytes()) should stay constant", prop.ForAll( + func(a testPairElement) bool { + var b Element + bytes := a.element.Bytes() + b.SetBytes(bytes[:]) + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementInverseExp(t *testing.T) { + // inverse must be equal to exp^-2 + exp := Modulus() + exp.Sub(exp, new(big.Int).SetUint64(2)) + + invMatchExp := func(a testPairElement) bool { + var b Element + b.Set(&a.element) + a.element.Inverse(&a.element) + b.Exp(b, exp) + + return a.element.Equal(&b) + } + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + properties := gopter.NewProperties(parameters) + genA := gen() + properties.Property("inv == exp^-2", prop.ForAll(invMatchExp, genA)) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + parameters.MinSuccessfulTests = 1 + properties = gopter.NewProperties(parameters) + properties.Property("inv(0) == 0", prop.ForAll(invMatchExp, ggen.OneConstOf(testPairElement{}))) + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func mulByConstant(z *Element, c uint8) { + var y Element + y.SetUint64(uint64(c)) + z.Mul(z, &y) +} + +func TestElementMulByConstants(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + implemented := []uint8{0, 1, 2, 3, 5, 13} + properties.Property("mulByConstant", prop.ForAll( + func(a testPairElement) bool { + for _, c := range implemented { + var constant Element + constant.SetUint64(uint64(c)) + + b := a.element + b.Mul(&b, &constant) + + aa := a.element + mulByConstant(&aa, c) + + if !aa.Equal(&b) { + return false + } + } + + return true + }, + genA, + )) + + properties.Property("MulBy3(x) == Mul(x, 3)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(3) + + b := a.element + b.Mul(&b, &constant) + + MulBy3(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy5(x) == Mul(x, 5)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(5) + + b := a.element + b.Mul(&b, &constant) + + MulBy5(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("MulBy13(x) == Mul(x, 13)", prop.ForAll( + func(a testPairElement) bool { + var constant Element + constant.SetUint64(13) + + b := a.element + b.Mul(&b, &constant) + + MulBy13(&a.element) + + return a.element.Equal(&b) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLegendre(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("legendre should output same result than big.Int.Jacobi", prop.ForAll( + func(a testPairElement) bool { + return a.element.Legendre() == big.Jacobi(&a.bigint, Modulus()) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementBitLen(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( + func(a testPairElement) bool { + return a.element.fromMont().BitLen() == a.bigint.BitLen() + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementButterflies(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("butterfly0 == a -b; a +b", prop.ForAll( + func(a, b testPairElement) bool { + a0, b0 := a.element, b.element + + _butterflyGeneric(&a.element, &b.element) + Butterfly(&a0, &b0) + + return a.element.Equal(&a0) && b.element.Equal(&b0) + }, + genA, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementLexicographicallyLargest(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("element.Cmp should match LexicographicallyLargest output", prop.ForAll( + func(a testPairElement) bool { + var negA Element + negA.Neg(&a.element) + + cmpResult := a.element.Cmp(&negA) + lResult := a.element.LexicographicallyLargest() + + if lResult && cmpResult == 1 { + return true + } + if !lResult && cmpResult != 1 { + return true + } + return false + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + +} + +func TestElementAdd(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Add: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Add(&a.element, &b.element) + a.element.Add(&a.element, &b.element) + b.element.Add(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Add: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Add(&a.element, &b.element) + + var d, e big.Int + d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Add(&a.element, &r) + d.Add(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Add: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Add(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Add(&a, &b) + d.Add(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Add failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSub(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Sub: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Sub(&a.element, &b.element) + a.element.Sub(&a.element, &b.element) + b.element.Sub(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Sub(&a.element, &b.element) + + var d, e big.Int + d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Sub(&a.element, &r) + d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Sub: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Sub(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Sub(&a, &b) + d.Sub(&aBig, &bBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sub failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementMul(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Mul: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Mul(&a.element, &b.element) + a.element.Mul(&a.element, &b.element) + b.element.Mul(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Mul(&a.element, &b.element) + + var d, e big.Int + d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Mul(&a.element, &r) + d.Mul(&a.bigint, &rb).Mod(&d, Modulus()) + + // checking generic impl against asm path + var cGeneric Element + _mulGeneric(&cGeneric, &a.element, &r) + if !cGeneric.Equal(&c) { + // need to give context to failing error. + return false + } + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Mul: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Mul(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + properties.Property("Mul: assembly implementation must be consistent with generic one", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + c.Mul(&a.element, &b.element) + _mulGeneric(&d, &a.element, &b.element) + return c.Equal(&d) + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Mul(&a, &b) + d.Mul(&aBig, &bBig).Mod(&d, Modulus()) + + // checking asm against generic impl + var cGeneric Element + _mulGeneric(&cGeneric, &a, &b) + if !cGeneric.Equal(&c) { + t.Fatal("Mul failed special test values: asm and generic impl don't match") + } + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Mul failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDiv(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Div: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Div(&a.element, &b.element) + a.element.Div(&a.element, &b.element) + b.element.Div(&d, &b.element) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Div: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Div(&a.element, &b.element) + + var d, e big.Int + d.ModInverse(&b.bigint, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Div(&a.element, &r) + d.ModInverse(&rb, Modulus()) + d.Mul(&d, &a.bigint).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Div: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Div(&a.element, &b.element) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Div(&a, &b) + d.ModInverse(&bBig, Modulus()) + d.Mul(&d, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Div failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementExp(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genB := gen() + + properties.Property("Exp: having the receiver as operand should output the same result", prop.ForAll( + func(a, b testPairElement) bool { + var c, d Element + d.Set(&a.element) + + c.Exp(a.element, &b.bigint) + a.element.Exp(a.element, &b.bigint) + b.element.Exp(d, &b.bigint) + + return a.element.Equal(&b.element) && a.element.Equal(&c) && b.element.Equal(&c) + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must match big.Int result", prop.ForAll( + func(a, b testPairElement) bool { + { + var c Element + + c.Exp(a.element, &b.bigint) + + var d, e big.Int + d.Exp(&a.bigint, &b.bigint, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + + // fixed elements + // a is random + // r takes special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, r := range testValues { + var d, e, rb big.Int + r.BigInt(&rb) + + var c Element + c.Exp(a.element, &rb) + d.Exp(&a.bigint, &rb, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + return false + } + } + return true + }, + genA, + genB, + )) + + properties.Property("Exp: operation result must be smaller than modulus", prop.ForAll( + func(a, b testPairElement) bool { + var c Element + + c.Exp(a.element, &b.bigint) + + return c.smallerThanModulus() + }, + genA, + genB, + )) + + specialValueTest := func() { + // test special values against special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + for _, b := range testValues { + + var bBig, d, e big.Int + b.BigInt(&bBig) + + var c Element + c.Exp(a, &bBig) + d.Exp(&aBig, &bBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Exp failed special test values") + } + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSquare(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Square: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Square(&a.element) + a.element.Square(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Square: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + + var d, e big.Int + d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Square: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Square(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Square(&a) + + var d, e big.Int + d.Mul(&aBig, &aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Square failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementInverse(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Inverse: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Inverse(&a.element) + a.element.Inverse(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Inverse: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + + var d, e big.Int + d.ModInverse(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Inverse: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Inverse(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Inverse(&a) + + var d, e big.Int + d.ModInverse(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Inverse failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementSqrt(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Sqrt: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + b := a.element + + b.Sqrt(&a.element) + a.element.Sqrt(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Sqrt: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + + var d, e big.Int + d.ModSqrt(&a.bigint, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Sqrt: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Sqrt(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Sqrt(&a) + + var d, e big.Int + d.ModSqrt(&aBig, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Sqrt failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementDouble(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Double: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Double(&a.element) + a.element.Double(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Double: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + + var d, e big.Int + d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Double: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Double(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Double(&a) + + var d, e big.Int + d.Lsh(&aBig, 1).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Double failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementNeg(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Neg: having the receiver as operand should output the same result", prop.ForAll( + func(a testPairElement) bool { + + var b Element + + b.Neg(&a.element) + a.element.Neg(&a.element) + return a.element.Equal(&b) + }, + genA, + )) + + properties.Property("Neg: operation result must match big.Int result", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + + var d, e big.Int + d.Neg(&a.bigint).Mod(&d, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, + )) + + properties.Property("Neg: operation result must be smaller than modulus", prop.ForAll( + func(a testPairElement) bool { + var c Element + c.Neg(&a.element) + return c.smallerThanModulus() + }, + genA, + )) + + specialValueTest := func() { + // test special values + testValues := make([]Element, len(staticTestValues)) + copy(testValues, staticTestValues) + + for _, a := range testValues { + var aBig big.Int + a.BigInt(&aBig) + var c Element + c.Neg(&a) + + var d, e big.Int + d.Neg(&aBig).Mod(&d, Modulus()) + + if c.BigInt(&e).Cmp(&d) != 0 { + t.Fatal("Neg failed special test values") + } + } + } + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + specialValueTest() + +} + +func TestElementFixedExp(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + var ( + _bLegendreExponentElement *big.Int + _bSqrtExponentElement *big.Int + ) + + _bLegendreExponentElement, _ = new(big.Int).SetString("7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0", 16) + const sqrtExponentElement = "1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82" + _bSqrtExponentElement, _ = new(big.Int).SetString(sqrtExponentElement, 16) + + genA := gen() + + properties.Property(fmt.Sprintf("expBySqrtExp must match Exp(%s)", sqrtExponentElement), prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expBySqrtExp(c) + d.Exp(d, _bSqrtExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("expByLegendreExp must match Exp(7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0)", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.expByLegendreExp(c) + d.Exp(d, _bLegendreExponentElement) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementHalve(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + var twoInv Element + twoInv.SetUint64(2) + twoInv.Inverse(&twoInv) + + properties.Property("z.Halve must match z / 2", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.Halve() + d.Mul(&d, &twoInv) + return c.Equal(&d) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func combineSelectionArguments(c int64, z int8) int { + if z%3 == 0 { + return 0 + } + return int(c) +} + +func TestElementSelect(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := genFull() + genB := genFull() + genC := ggen.Int64() //the condition + genZ := ggen.Int8() //to make zeros artificially more likely + + properties.Property("Select: must select correctly", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c Element + c.Select(condC, &a, &b) + + if condC == 0 { + return c.Equal(&a) + } + return c.Equal(&b) + }, + genA, + genB, + genC, + genZ, + )) + + properties.Property("Select: having the receiver as operand should output the same result", prop.ForAll( + func(a, b Element, cond int64, z int8) bool { + condC := combineSelectionArguments(cond, z) + + var c, d Element + d.Set(&a) + c.Select(condC, &a, &b) + a.Select(condC, &a, &b) + b.Select(condC, &d, &b) + return a.Equal(&b) && a.Equal(&c) && b.Equal(&c) + }, + genA, + genB, + genC, + genZ, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInt64(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("z.SetInt64 must match z.SetString", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInt64(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, ggen.Int64(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementSetInterface(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + genInt := ggen.Int + genInt8 := ggen.Int8 + genInt16 := ggen.Int16 + genInt32 := ggen.Int32 + genInt64 := ggen.Int64 + + genUint := ggen.UInt + genUint8 := ggen.UInt8 + genUint16 := ggen.UInt16 + genUint32 := ggen.UInt32 + genUint64 := ggen.UInt64 + + properties.Property("z.SetInterface must match z.SetString with int8", prop.ForAll( + func(a testPairElement, v int8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt8(), + )) + + properties.Property("z.SetInterface must match z.SetString with int16", prop.ForAll( + func(a testPairElement, v int16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt16(), + )) + + properties.Property("z.SetInterface must match z.SetString with int32", prop.ForAll( + func(a testPairElement, v int32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt32(), + )) + + properties.Property("z.SetInterface must match z.SetString with int64", prop.ForAll( + func(a testPairElement, v int64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt64(), + )) + + properties.Property("z.SetInterface must match z.SetString with int", prop.ForAll( + func(a testPairElement, v int) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genInt(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint8", prop.ForAll( + func(a testPairElement, v uint8) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint8(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint16", prop.ForAll( + func(a testPairElement, v uint16) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint16(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint32", prop.ForAll( + func(a testPairElement, v uint32) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint32(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint64", prop.ForAll( + func(a testPairElement, v uint64) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint64(), + )) + + properties.Property("z.SetInterface must match z.SetString with uint", prop.ForAll( + func(a testPairElement, v uint) bool { + c := a.element + d := a.element + + c.SetInterface(v) + d.SetString(fmt.Sprintf("%v", v)) + + return c.Equal(&d) + }, + genA, genUint(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + { + assert := require.New(t) + var e Element + r, err := e.SetInterface(nil) + assert.Nil(r) + assert.Error(err) + + var ptE *Element + var ptB *big.Int + + r, err = e.SetInterface(ptE) + assert.Nil(r) + assert.Error(err) + ptE = new(Element).SetOne() + r, err = e.SetInterface(ptE) + assert.NoError(err) + assert.True(r.IsOne()) + + r, err = e.SetInterface(ptB) + assert.Nil(r) + assert.Error(err) + + } +} + +func TestElementNegativeExp(t *testing.T) { + t.Parallel() + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("x⁻ᵏ == 1/xᵏ", prop.ForAll( + func(a, b testPairElement) bool { + + var nb, d, e big.Int + nb.Neg(&b.bigint) + + var c Element + c.Exp(a.element, &nb) + + d.Exp(&a.bigint, &nb, Modulus()) + + return c.BigInt(&e).Cmp(&d) == 0 + }, + genA, genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementNewElement(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + e := NewElement(1) + assert.True(e.IsOne()) + + e = NewElement(0) + assert.True(e.IsZero()) +} + +func TestElementBatchInvert(t *testing.T) { + assert := require.New(t) + + t.Parallel() + + // ensure batchInvert([x]) == invert(x) + for i := int64(-1); i <= 2; i++ { + var e, eInv Element + e.SetInt64(i) + eInv.Inverse(&e) + + a := []Element{e} + aInv := BatchInvert(a) + + assert.True(aInv[0].Equal(&eInv), "batchInvert != invert") + + } + + // test x * x⁻¹ == 1 + tData := [][]int64{ + {-1, 1, 2, 3}, + {0, -1, 1, 2, 3, 0}, + {0, -1, 1, 0, 2, 3, 0}, + {-1, 1, 0, 2, 3}, + {0, 0, 1}, + {1, 0, 0}, + {0, 0, 0}, + } + + for _, t := range tData { + a := make([]Element, len(t)) + for i := 0; i < len(a); i++ { + a[i].SetInt64(t[i]) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + assert.True(aInv[i].IsZero(), "0⁻¹ != 0") + } else { + assert.True(a[i].Mul(&a[i], &aInv[i]).IsOne(), "x * x⁻¹ != 1") + } + } + } + + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("batchInvert --> x * x⁻¹ == 1", prop.ForAll( + func(tp testPairElement, r uint8) bool { + + a := make([]Element, r) + if r != 0 { + a[0] = tp.element + + } + one := One() + for i := 1; i < len(a); i++ { + a[i].Add(&a[i-1], &one) + } + + aInv := BatchInvert(a) + + assert.True(len(aInv) == len(a)) + + for i := 0; i < len(a); i++ { + if a[i].IsZero() { + if !aInv[i].IsZero() { + return false + } + } else { + if !a[i].Mul(&a[i], &aInv[i]).IsOne() { + return false + } + } + } + return true + }, + genA, ggen.UInt8(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementFromMont(t *testing.T) { + + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + genA := gen() + + properties.Property("Assembly implementation must be consistent with generic one", prop.ForAll( + func(a testPairElement) bool { + c := a.element + d := a.element + c.fromMont() + _fromMontGeneric(&d) + return c.Equal(&d) + }, + genA, + )) + + properties.Property("x.fromMont().toMont() == x", prop.ForAll( + func(a testPairElement) bool { + c := a.element + c.fromMont().toMont() + return c.Equal(&a.element) + }, + genA, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestElementJSON(t *testing.T) { + assert := require.New(t) + + type S struct { + A Element + B [3]Element + C *Element + D *Element + } + + // encode to JSON + var s S + s.A.SetString("-1") + s.B[2].SetUint64(42) + s.D = new(Element).SetUint64(8000) + + encoded, err := json.Marshal(&s) + assert.NoError(err) + const expected = "{\"A\":-1,\"B\":[0,0,42],\"C\":null,\"D\":8000}" + assert.Equal(expected, string(encoded)) + + // decode valid + var decoded S + err = json.Unmarshal([]byte(expected), &decoded) + assert.NoError(err) + + assert.Equal(s, decoded, "element -> json -> element round trip failed") + + // decode hex and string values + withHexValues := "{\"A\":\"-1\",\"B\":[0,\"0x00000\",\"0x2A\"],\"C\":null,\"D\":\"8000\"}" + + var decodedS S + err = json.Unmarshal([]byte(withHexValues), &decodedS) + assert.NoError(err) + + assert.Equal(s, decodedS, " json with strings -> element failed") + +} + +type testPairElement struct { + element Element + bigint big.Int +} + +func gen() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var g testPairElement + + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + + for !g.element.smallerThanModulus() { + g.element = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g.element[3] %= (qElement[3] + 1) + } + } + + g.element.BigInt(&g.bigint) + genResult := gopter.NewGenResult(g, gopter.NoShrinker) + return genResult + } +} + +func genFull() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + + genRandomFq := func() Element { + var g Element + + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + + for !g.smallerThanModulus() { + g = Element{ + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + genParams.NextUint64(), + } + if qElement[3] != ^uint64(0) { + g[3] %= (qElement[3] + 1) + } + } + + return g + } + a := genRandomFq() + + var carry uint64 + a[0], carry = bits.Add64(a[0], qElement[0], carry) + a[1], carry = bits.Add64(a[1], qElement[1], carry) + a[2], carry = bits.Add64(a[2], qElement[2], carry) + a[3], _ = bits.Add64(a[3], qElement[3], carry) + + genResult := gopter.NewGenResult(a, gopter.NoShrinker) + return genResult + } +} diff --git a/field/goldilocks/element_ops_noasm.go b/ecc/secp256k1/fr/polynomial/doc.go similarity index 54% rename from field/goldilocks/element_ops_noasm.go rename to ecc/secp256k1/fr/polynomial/doc.go index 4d9ba479c..83479b058 100644 --- a/field/goldilocks/element_ops_noasm.go +++ b/ecc/secp256k1/fr/polynomial/doc.go @@ -14,40 +14,5 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -package goldilocks - -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - var y Element - y.SetUint64(3) - x.Mul(x, &y) -} - -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - var y Element - y.SetUint64(5) - x.Mul(x, &y) -} - -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y Element - y.SetUint64(13) - x.Mul(x, &y) -} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} - -func fromMont(z *Element) { - _fromMontGeneric(z) -} - -func reduce(z *Element) { - _reduceGeneric(z) -} +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial diff --git a/ecc/secp256k1/fr/polynomial/multilin.go b/ecc/secp256k1/fr/polynomial/multilin.go new file mode 100644 index 000000000..37e10ef3e --- /dev/null +++ b/ecc/secp256k1/fr/polynomial/multilin.go @@ -0,0 +1,271 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "math/bits" +) + +// MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial +// The variables are X₁ through Xₙ where n = log(len(.)) +// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) +// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial +type MultiLin []fr.Element + +// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r +func (m *MultiLin) Fold(r fr.Element) { + mid := len(*m) / 2 + + bottom, top := (*m)[:mid], (*m)[mid:] + + // updating bookkeeping table + // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) + // the following loop computes the evaluations of f(r) accordingly: + // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) + for i := 0; i < mid; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + top[i].Sub(&top[i], &bottom[i]) + top[i].Mul(&top[i], &r) + bottom[i].Add(&bottom[i], &top[i]) + } + + *m = (*m)[:mid] +} + +func (m MultiLin) Sum() fr.Element { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate extrapolate the value of the multilinear polynomial corresponding to m +// on the given coordinates +func (m MultiLin) Evaluate(coordinates []fr.Element, p *Pool) fr.Element { + // Folding is a mutating operation + bkCopy := _clone(m, p) + + // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) + for _, r := range coordinates { + bkCopy.Fold(r) + } + + result := bkCopy[0] + + _dump(bkCopy, p) + return result +} + +// Clone creates a deep copy of a bookkeeping table. +// Both multilinear interpolation and sumcheck require folding an underlying +// array, but folding changes the array. To do both one requires a deep copy +// of the bookkeeping table. +func (m MultiLin) Clone() MultiLin { + res := make(MultiLin, len(m)) + copy(res, m) + return res +} + +// Add two bookKeepingTables +func (m *MultiLin) Add(left, right MultiLin) { + size := len(left) + // Check that left and right have the same size + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") + } + + // Add elementwise + for i := 0; i < size; i++ { + (*m)[i].Add(&left[i], &right[i]) + } +} + +// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) +// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates +// +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// +// In other words the polynomial evaluated here is the multilinear extrapolation of +// one that evaluates to q' == h' for vectors q', h' of binary values +func EvalEq(q, h []fr.Element) fr.Element { + var res, nxt, one, sum fr.Element + one.SetOne() + for i := 0; i < len(q); i++ { + nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ + nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ + nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ + sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? + + if i == 0 { + res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + } else { + nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + res.Mul(&res, &nxt) // res <- res * nxt + } + } + return res +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(q []fr.Element) { + n := len(q) + + if len(*m) != 1<= 0; i-- { + res.Mul(&res, v) + res.Add(&res, &(*p)[i]) + } + + return res +} + +// Clone returns a copy of the polynomial +func (p *Polynomial) Clone() Polynomial { + _p := make(Polynomial, len(*p)) + copy(_p, *p) + return _p +} + +// Set to another polynomial +func (p *Polynomial) Set(p1 Polynomial) { + if len(*p) != len(p1) { + *p = p1.Clone() + return + } + + for i := 0; i < len(p1); i++ { + (*p)[i].Set(&p1[i]) + } +} + +// AddConstantInPlace adds a constant to the polynomial, modifying p +func (p *Polynomial) AddConstantInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Add(&(*p)[i], c) + } +} + +// SubConstantInPlace subs a constant to the polynomial, modifying p +func (p *Polynomial) SubConstantInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&(*p)[i], c) + } +} + +// ScaleInPlace multiplies p by v, modifying p +func (p *Polynomial) ScaleInPlace(c *fr.Element) { + for i := 0; i < len(*p); i++ { + (*p)[i].Mul(&(*p)[i], c) + } +} + +// Scale multiplies p0 by v, storing the result in p +func (p *Polynomial) Scale(c *fr.Element, p0 Polynomial) { + if len(*p) != len(p0) { + *p = make(Polynomial, len(p0)) + } + for i := 0; i < len(p0); i++ { + (*p)[i].Mul(c, &p0[i]) + } +} + +// Add adds p1 to p2 +// This function allocates a new slice unless p == p1 or p == p2 +func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { + + bigger := p1 + smaller := p2 + if len(bigger) < len(smaller) { + bigger, smaller = smaller, bigger + } + + if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &smaller[i]) + } + return p + } + + if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &bigger[i]) + } + *p = append(*p, bigger[len(smaller):]...) + return p + } + + res := make(Polynomial, len(bigger)) + copy(res, bigger) + for i := 0; i < len(smaller); i++ { + res[i].Add(&res[i], &smaller[i]) + } + *p = res + return p +} + +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + +// Equal checks equality between two polynomials +func (p *Polynomial) Equal(p1 Polynomial) bool { + if (*p == nil) != (p1 == nil) { + return false + } + + if len(*p) != len(p1) { + return false + } + + for i := range p1 { + if !(*p)[i].Equal(&p1[i]) { + return false + } + } + + return true +} + +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() + } +} + +func (p Polynomial) Text(base int) string { + + var builder strings.Builder + + first := true + for d := len(p) - 1; d >= 0; d-- { + if p[d].IsZero() { + continue + } + + pD := p[d] + pDText := pD.Text(base) + + initialLen := builder.Len() + + if pDText[0] == '-' { + pDText = pDText[1:] + if first { + builder.WriteString("-") + } else { + builder.WriteString(" - ") + } + } else if !first { + builder.WriteString(" + ") + } + + first = false + + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) + } + + if builder.Len()-initialLen > 10 { + builder.WriteString("×") + } + + if d != 0 { + builder.WriteString("X") + } + if d > 1 { + builder.WriteString( + utils.ToSuperscript(strconv.Itoa(d)), + ) + } + + } + + if first { + return "0" + } + + return builder.String() +} diff --git a/ecc/secp256k1/fr/polynomial/polynomial_test.go b/ecc/secp256k1/fr/polynomial/polynomial_test.go new file mode 100644 index 000000000..5f8abe1fc --- /dev/null +++ b/ecc/secp256k1/fr/polynomial/polynomial_test.go @@ -0,0 +1,218 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/stretchr/testify/assert" + "math/big" + "testing" +) + +func TestPolynomialEval(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // random value + var point fr.Element + point.SetRandom() + + // compute manually f(val) + var expectedEval, one, den fr.Element + var expo big.Int + one.SetOne() + expo.SetUint64(20) + expectedEval.Exp(point, &expo). + Sub(&expectedEval, &one) + den.Sub(&point, &one) + expectedEval.Div(&expectedEval, &den) + + // compute purported evaluation + purportedEval := f.Eval(&point) + + // check + if !purportedEval.Equal(&expectedEval) { + t.Fatal("polynomial evaluation failed") + } +} + +func TestPolynomialAddConstantInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to add + var c fr.Element + c.SetRandom() + + // add constant + f.AddConstantInPlace(&c) + + // check + var expectedCoeffs, one fr.Element + one.SetOne() + expectedCoeffs.Add(&one, &c) + for i := 0; i < 20; i++ { + if !f[i].Equal(&expectedCoeffs) { + t.Fatal("AddConstantInPlace failed") + } + } +} + +func TestPolynomialSubConstantInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to sub + var c fr.Element + c.SetRandom() + + // sub constant + f.SubConstantInPlace(&c) + + // check + var expectedCoeffs, one fr.Element + one.SetOne() + expectedCoeffs.Sub(&one, &c) + for i := 0; i < 20; i++ { + if !f[i].Equal(&expectedCoeffs) { + t.Fatal("SubConstantInPlace failed") + } + } +} + +func TestPolynomialScaleInPlace(t *testing.T) { + + // build polynomial + f := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f[i].SetOne() + } + + // constant to scale by + var c fr.Element + c.SetRandom() + + // scale by constant + f.ScaleInPlace(&c) + + // check + for i := 0; i < 20; i++ { + if !f[i].Equal(&c) { + t.Fatal("ScaleInPlace failed") + } + } + +} + +func TestPolynomialAdd(t *testing.T) { + + // build unbalanced polynomials + f1 := make(Polynomial, 20) + f1Backup := make(Polynomial, 20) + for i := 0; i < 20; i++ { + f1[i].SetOne() + f1Backup[i].SetOne() + } + f2 := make(Polynomial, 10) + f2Backup := make(Polynomial, 10) + for i := 0; i < 10; i++ { + f2[i].SetOne() + f2Backup[i].SetOne() + } + + // expected result + var one, two fr.Element + one.SetOne() + two.Double(&one) + expectedSum := make(Polynomial, 20) + for i := 0; i < 10; i++ { + expectedSum[i].Set(&two) + } + for i := 10; i < 20; i++ { + expectedSum[i].Set(&one) + } + + // caller is empty + var g Polynomial + g.Add(f1, f2) + if !g.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !f1.Equal(f1Backup) { + t.Fatal("side effect, f1 should not have been modified") + } + if !f2.Equal(f2Backup) { + t.Fatal("side effect, f2 should not have been modified") + } + + // all operands are distincts + _f1 := f1.Clone() + _f1.Add(f1, f2) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !f1.Equal(f1Backup) { + t.Fatal("side effect, f1 should not have been modified") + } + if !f2.Equal(f2Backup) { + t.Fatal("side effect, f2 should not have been modified") + } + + // first operand = caller + _f1 = f1.Clone() + _f2 := f2.Clone() + _f1.Add(_f1, _f2) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !_f2.Equal(f2Backup) { + t.Fatal("side effect, _f2 should not have been modified") + } + + // second operand = caller + _f1 = f1.Clone() + _f2 = f2.Clone() + _f1.Add(_f2, _f1) + if !_f1.Equal(expectedSum) { + t.Fatal("add polynomials fails") + } + if !_f2.Equal(f2Backup) { + t.Fatal("side effect, _f2 should not have been modified") + } +} + +func TestPolynomialText(t *testing.T) { + var one, negTwo fr.Element + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/ecc/secp256k1/fr/polynomial/pool.go b/ecc/secp256k1/fr/polynomial/pool.go new file mode 100644 index 000000000..478ff52c8 --- /dev/null +++ b/ecc/secp256k1/fr/polynomial/pool.go @@ -0,0 +1,203 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "reflect" + "runtime" + "sort" + "sync" + "unsafe" +) + +// Memory management for polynomials +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly + +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} + +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} + +type Pool struct { + //lock sync.Mutex + inUse map[*fr.Element]inUseData + subPools []sizedPool +} + +func (p *sizedPool) get(n int) *fr.Element { + p.stats.maake(n) + return p.pool.Get().(*fr.Element) +} + +func (p *sizedPool) put(ptr *fr.Element) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*fr.Element]inUseData), + subPools: make([]sizedPool, len(maxN)), + } + + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]fr.Element, 0, subPool.maxN)) + }, + } + } + return +} + +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large +} + +func (p *Pool) Make(n int) []fr.Element { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr, n) +} + +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]fr.Element) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } + } +} + +func (p *Pool) addInUse(ptr *fr.Element, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) + + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) + } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} + +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) +} + +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) + } + } +} + +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []fr.Element) *fr.Element { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*fr.Element)(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse + } + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, + } + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []fr.Element) []fr.Element { + res := p.Make(len(slice)) + copy(res, slice) + return res +} diff --git a/ecc/secp256k1/g1.go b/ecc/secp256k1/g1.go new file mode 100644 index 000000000..4b34b63fc --- /dev/null +++ b/ecc/secp256k1/g1.go @@ -0,0 +1,845 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package secp256k1 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/consensys/gnark-crypto/internal/parallel" +) + +// G1Affine point in affine coordinates +type G1Affine struct { + X, Y fp.Element +} + +// G1Jac is a point with fp.Element coordinates +type G1Jac struct { + X, Y, Z fp.Element +} + +// g1JacExtended parameterized Jacobian coordinates (x=X/ZZ, y=Y/ZZZ, ZZ³=ZZZ²) +type g1JacExtended struct { + X, Y, ZZ, ZZZ fp.Element +} + +// ------------------------------------------------------------------------------------------------- +// Affine + +// Set sets p to the provided point +func (p *G1Affine) Set(a *G1Affine) *G1Affine { + p.X, p.Y = a.X, a.Y + return p +} + +// setInfinity sets p to O +func (p *G1Affine) setInfinity() *G1Affine { + p.X.SetZero() + p.Y.SetZero() + return p +} + +// ScalarMultiplication computes and returns p = a ⋅ s +func (p *G1Affine) ScalarMultiplication(a *G1Affine, s *big.Int) *G1Affine { + var _p G1Jac + _p.FromAffine(a) + _p.mulGLV(&_p, s) + p.FromJacobian(&_p) + return p +} + +// ScalarMultiplicationAffine computes and returns p = a ⋅ s +// Takes an affine point and returns a Jacobian point (useful for KZG) +func (p *G1Jac) ScalarMultiplicationAffine(a *G1Affine, s *big.Int) *G1Jac { + p.FromAffine(a) + p.mulGLV(p, s) + return p +} + +// Add adds two point in affine coordinates. +// This should rarely be used as it is very inefficient compared to Jacobian +func (p *G1Affine) Add(a, b *G1Affine) *G1Affine { + var p1, p2 G1Jac + p1.FromAffine(a) + p2.FromAffine(b) + p1.AddAssign(&p2) + p.FromJacobian(&p1) + return p +} + +// Sub subs two point in affine coordinates. +// This should rarely be used as it is very inefficient compared to Jacobian +func (p *G1Affine) Sub(a, b *G1Affine) *G1Affine { + var p1, p2 G1Jac + p1.FromAffine(a) + p2.FromAffine(b) + p1.SubAssign(&p2) + p.FromJacobian(&p1) + return p +} + +// Equal tests if two points (in Affine coordinates) are equal +func (p *G1Affine) Equal(a *G1Affine) bool { + return p.X.Equal(&a.X) && p.Y.Equal(&a.Y) +} + +// Neg computes -G +func (p *G1Affine) Neg(a *G1Affine) *G1Affine { + p.X = a.X + p.Y.Neg(&a.Y) + return p +} + +// FromJacobian rescales a point in Jacobian coord in z=1 plane +func (p *G1Affine) FromJacobian(p1 *G1Jac) *G1Affine { + + var a, b fp.Element + + if p1.Z.IsZero() { + p.X.SetZero() + p.Y.SetZero() + return p + } + + a.Inverse(&p1.Z) + b.Square(&a) + p.X.Mul(&p1.X, &b) + p.Y.Mul(&p1.Y, &b).Mul(&p.Y, &a) + + return p +} + +// String returns the string representation of the point or "O" if it is infinity +func (p *G1Affine) String() string { + if p.IsInfinity() { + return "O" + } + return "E([" + p.X.String() + "," + p.Y.String() + "])" +} + +// IsInfinity checks if the point is infinity +// in affine, it's encoded as (0,0) +// (0,0) is never on the curve for j=0 curves +func (p *G1Affine) IsInfinity() bool { + return p.X.IsZero() && p.Y.IsZero() +} + +// IsOnCurve returns true if p in on the curve +func (p *G1Affine) IsOnCurve() bool { + var point G1Jac + point.FromAffine(p) + return point.IsOnCurve() // call this function to handle infinity point +} + +// IsInSubGroup returns true if p is in the correct subgroup, false otherwise +func (p *G1Affine) IsInSubGroup() bool { + var _p G1Jac + _p.FromAffine(p) + return _p.IsInSubGroup() +} + +// ------------------------------------------------------------------------------------------------- +// Jacobian + +// Set sets p to the provided point +func (p *G1Jac) Set(a *G1Jac) *G1Jac { + p.X, p.Y, p.Z = a.X, a.Y, a.Z + return p +} + +// Equal tests if two points (in Jacobian coordinates) are equal +func (p *G1Jac) Equal(a *G1Jac) bool { + + if p.Z.IsZero() && a.Z.IsZero() { + return true + } + _p := G1Affine{} + _p.FromJacobian(p) + + _a := G1Affine{} + _a.FromJacobian(a) + + return _p.X.Equal(&_a.X) && _p.Y.Equal(&_a.Y) +} + +// Neg computes -G +func (p *G1Jac) Neg(a *G1Jac) *G1Jac { + *p = *a + p.Y.Neg(&a.Y) + return p +} + +// SubAssign subtracts two points on the curve +func (p *G1Jac) SubAssign(a *G1Jac) *G1Jac { + var tmp G1Jac + tmp.Set(a) + tmp.Y.Neg(&tmp.Y) + p.AddAssign(&tmp) + return p +} + +// AddAssign point addition in montgomery form +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#addition-add-2007-bl +func (p *G1Jac) AddAssign(a *G1Jac) *G1Jac { + + // p is infinity, return a + if p.Z.IsZero() { + p.Set(a) + return p + } + + // a is infinity, return p + if a.Z.IsZero() { + return p + } + + var Z1Z1, Z2Z2, U1, U2, S1, S2, H, I, J, r, V fp.Element + Z1Z1.Square(&a.Z) + Z2Z2.Square(&p.Z) + U1.Mul(&a.X, &Z2Z2) + U2.Mul(&p.X, &Z1Z1) + S1.Mul(&a.Y, &p.Z). + Mul(&S1, &Z2Z2) + S2.Mul(&p.Y, &a.Z). + Mul(&S2, &Z1Z1) + + // if p == a, we double instead + if U1.Equal(&U2) && S1.Equal(&S2) { + return p.DoubleAssign() + } + + H.Sub(&U2, &U1) + I.Double(&H). + Square(&I) + J.Mul(&H, &I) + r.Sub(&S2, &S1).Double(&r) + V.Mul(&U1, &I) + p.X.Square(&r). + Sub(&p.X, &J). + Sub(&p.X, &V). + Sub(&p.X, &V) + p.Y.Sub(&V, &p.X). + Mul(&p.Y, &r) + S1.Mul(&S1, &J).Double(&S1) + p.Y.Sub(&p.Y, &S1) + p.Z.Add(&p.Z, &a.Z) + p.Z.Square(&p.Z). + Sub(&p.Z, &Z1Z1). + Sub(&p.Z, &Z2Z2). + Mul(&p.Z, &H) + + return p +} + +// AddMixed point addition +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-madd-2007-bl +func (p *G1Jac) AddMixed(a *G1Affine) *G1Jac { + + //if a is infinity return p + if a.IsInfinity() { + return p + } + // p is infinity, return a + if p.Z.IsZero() { + p.X = a.X + p.Y = a.Y + p.Z.SetOne() + return p + } + + var Z1Z1, U2, S2, H, HH, I, J, r, V fp.Element + Z1Z1.Square(&p.Z) + U2.Mul(&a.X, &Z1Z1) + S2.Mul(&a.Y, &p.Z). + Mul(&S2, &Z1Z1) + + // if p == a, we double instead + if U2.Equal(&p.X) && S2.Equal(&p.Y) { + return p.DoubleAssign() + } + + H.Sub(&U2, &p.X) + HH.Square(&H) + I.Double(&HH).Double(&I) + J.Mul(&H, &I) + r.Sub(&S2, &p.Y).Double(&r) + V.Mul(&p.X, &I) + p.X.Square(&r). + Sub(&p.X, &J). + Sub(&p.X, &V). + Sub(&p.X, &V) + J.Mul(&J, &p.Y).Double(&J) + p.Y.Sub(&V, &p.X). + Mul(&p.Y, &r) + p.Y.Sub(&p.Y, &J) + p.Z.Add(&p.Z, &H) + p.Z.Square(&p.Z). + Sub(&p.Z, &Z1Z1). + Sub(&p.Z, &HH) + + return p +} + +// Double doubles a point in Jacobian coordinates +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2007-bl +func (p *G1Jac) Double(q *G1Jac) *G1Jac { + p.Set(q) + p.DoubleAssign() + return p +} + +// DoubleAssign doubles a point in Jacobian coordinates +// https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-3.html#doubling-dbl-2007-bl +func (p *G1Jac) DoubleAssign() *G1Jac { + + var XX, YY, YYYY, ZZ, S, M, T fp.Element + + XX.Square(&p.X) + YY.Square(&p.Y) + YYYY.Square(&YY) + ZZ.Square(&p.Z) + S.Add(&p.X, &YY) + S.Square(&S). + Sub(&S, &XX). + Sub(&S, &YYYY). + Double(&S) + M.Double(&XX).Add(&M, &XX) + p.Z.Add(&p.Z, &p.Y). + Square(&p.Z). + Sub(&p.Z, &YY). + Sub(&p.Z, &ZZ) + T.Square(&M) + p.X = T + T.Double(&S) + p.X.Sub(&p.X, &T) + p.Y.Sub(&S, &p.X). + Mul(&p.Y, &M) + YYYY.Double(&YYYY).Double(&YYYY).Double(&YYYY) + p.Y.Sub(&p.Y, &YYYY) + + return p +} + +// ScalarMultiplication computes and returns p = a ⋅ s +// see https://www.iacr.org/archive/crypto2001/21390189.pdf +func (p *G1Jac) ScalarMultiplication(a *G1Jac, s *big.Int) *G1Jac { + return p.mulGLV(a, s) +} + +// String returns canonical representation of the point in affine coordinates +func (p *G1Jac) String() string { + _p := G1Affine{} + _p.FromJacobian(p) + return _p.String() +} + +// FromAffine sets p = Q, p in Jacobian, Q in affine +func (p *G1Jac) FromAffine(Q *G1Affine) *G1Jac { + if Q.IsInfinity() { + p.Z.SetZero() + p.X.SetOne() + p.Y.SetOne() + return p + } + p.Z.SetOne() + p.X.Set(&Q.X) + p.Y.Set(&Q.Y) + return p +} + +// IsOnCurve returns true if p in on the curve +func (p *G1Jac) IsOnCurve() bool { + var left, right, tmp fp.Element + left.Square(&p.Y) + right.Square(&p.X).Mul(&right, &p.X) + tmp.Square(&p.Z). + Square(&tmp). + Mul(&tmp, &p.Z). + Mul(&tmp, &p.Z). + Mul(&tmp, &bCurveCoeff) + right.Add(&right, &tmp) + return left.Equal(&right) +} + +// IsInSubGroup returns true if p is on the r-torsion, false otherwise. +// secp256k1 curve is of prime order i.e. E(𝔽p) is the full group +// so we just check that the point is on the curve. +func (p *G1Jac) IsInSubGroup() bool { + + return p.IsOnCurve() + +} + +// mulWindowed computes a 2-bits windowed scalar multiplication +func (p *G1Jac) mulWindowed(a *G1Jac, s *big.Int) *G1Jac { + + var res G1Jac + var ops [3]G1Jac + + res.Set(&g1Infinity) + ops[0].Set(a) + ops[1].Double(&ops[0]) + ops[2].Set(&ops[0]).AddAssign(&ops[1]) + + b := s.Bytes() + for i := range b { + w := b[i] + mask := byte(0xc0) + for j := 0; j < 4; j++ { + res.DoubleAssign().DoubleAssign() + c := (w & mask) >> (6 - 2*j) + if c != 0 { + res.AddAssign(&ops[c-1]) + } + mask = mask >> 2 + } + } + p.Set(&res) + + return p + +} + +// ϕ assigns p to ϕ(a) where ϕ: (x,y) → (w x,y), and returns p +// where w is a third root of unity in 𝔽p +func (p *G1Jac) phi(a *G1Jac) *G1Jac { + p.Set(a) + p.X.Mul(&p.X, &thirdRootOneG1) + return p +} + +// mulGLV computes the scalar multiplication using a windowed-GLV method +// see https://www.iacr.org/archive/crypto2001/21390189.pdf +func (p *G1Jac) mulGLV(a *G1Jac, s *big.Int) *G1Jac { + + var table [15]G1Jac + var res G1Jac + var k1, k2 fr.Element + + res.Set(&g1Infinity) + + // table[b3b2b1b0-1] = b3b2 ⋅ ϕ(a) + b1b0*a + table[0].Set(a) + table[3].phi(a) + + // split the scalar, modifies ±a, ϕ(a) accordingly + k := ecc.SplitScalar(s, &glvBasis) + + if k[0].Sign() == -1 { + k[0].Neg(&k[0]) + table[0].Neg(&table[0]) + } + if k[1].Sign() == -1 { + k[1].Neg(&k[1]) + table[3].Neg(&table[3]) + } + + // precompute table (2 bits sliding window) + // table[b3b2b1b0-1] = b3b2 ⋅ ϕ(a) + b1b0 ⋅ a if b3b2b1b0 != 0 + table[1].Double(&table[0]) + table[2].Set(&table[1]).AddAssign(&table[0]) + table[4].Set(&table[3]).AddAssign(&table[0]) + table[5].Set(&table[3]).AddAssign(&table[1]) + table[6].Set(&table[3]).AddAssign(&table[2]) + table[7].Double(&table[3]) + table[8].Set(&table[7]).AddAssign(&table[0]) + table[9].Set(&table[7]).AddAssign(&table[1]) + table[10].Set(&table[7]).AddAssign(&table[2]) + table[11].Set(&table[7]).AddAssign(&table[3]) + table[12].Set(&table[11]).AddAssign(&table[0]) + table[13].Set(&table[11]).AddAssign(&table[1]) + table[14].Set(&table[11]).AddAssign(&table[2]) + + // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max + // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() + + // we don't target constant-timeness so we check first if we increase the bounds or not + maxBit := k1.BitLen() + if k2.BitLen() > maxBit { + maxBit = k2.BitLen() + } + hiWordIndex := (maxBit - 1) / 64 + + // loop starts from len(k1)/2 or len(k1)/2+1 due to the bounds + for i := hiWordIndex; i >= 0; i-- { + mask := uint64(3) << 62 + for j := 0; j < 32; j++ { + res.Double(&res).Double(&res) + b1 := (k1[i] & mask) >> (62 - 2*j) + b2 := (k2[i] & mask) >> (62 - 2*j) + if b1|b2 != 0 { + s := (b2<<2 | b1) + res.AddAssign(&table[s-1]) + } + mask = mask >> 2 + } + } + + p.Set(&res) + return p +} + +// ------------------------------------------------------------------------------------------------- +// Jacobian extended + +// Set sets p to the provided point +func (p *g1JacExtended) Set(a *g1JacExtended) *g1JacExtended { + p.X, p.Y, p.ZZ, p.ZZZ = a.X, a.Y, a.ZZ, a.ZZZ + return p +} + +// setInfinity sets p to O +func (p *g1JacExtended) setInfinity() *g1JacExtended { + p.X.SetOne() + p.Y.SetOne() + p.ZZ = fp.Element{} + p.ZZZ = fp.Element{} + return p +} + +// fromJacExtended sets Q in affine coordinates +func (p *G1Affine) fromJacExtended(Q *g1JacExtended) *G1Affine { + if Q.ZZ.IsZero() { + p.X = fp.Element{} + p.Y = fp.Element{} + return p + } + p.X.Inverse(&Q.ZZ).Mul(&p.X, &Q.X) + p.Y.Inverse(&Q.ZZZ).Mul(&p.Y, &Q.Y) + return p +} + +// fromJacExtended sets Q in Jacobian coordinates +func (p *G1Jac) fromJacExtended(Q *g1JacExtended) *G1Jac { + if Q.ZZ.IsZero() { + p.Set(&g1Infinity) + return p + } + p.X.Mul(&Q.ZZ, &Q.X).Mul(&p.X, &Q.ZZ) + p.Y.Mul(&Q.ZZZ, &Q.Y).Mul(&p.Y, &Q.ZZZ) + p.Z.Set(&Q.ZZZ) + return p +} + +// unsafeFromJacExtended sets p in Jacobian coordinates, but don't check for infinity +func (p *G1Jac) unsafeFromJacExtended(Q *g1JacExtended) *G1Jac { + p.X.Square(&Q.ZZ).Mul(&p.X, &Q.X) + p.Y.Square(&Q.ZZZ).Mul(&p.Y, &Q.Y) + p.Z = Q.ZZZ + return p +} + +// add point in Jacobian extended coordinates +// https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-add-2008-s +func (p *g1JacExtended) add(q *g1JacExtended) *g1JacExtended { + //if q is infinity return p + if q.ZZ.IsZero() { + return p + } + // p is infinity, return q + if p.ZZ.IsZero() { + p.Set(q) + return p + } + + var A, B, U1, U2, S1, S2 fp.Element + + // p2: q, p1: p + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) + + if A.IsZero() { + if B.IsZero() { + return p.double(q) + + } + p.ZZ = fp.Element{} + p.ZZZ = fp.Element{} + return p + } + + var P, R, PP, PPP, Q, V fp.Element + P.Sub(&U2, &U1) + R.Sub(&S2, &S1) + PP.Square(&P) + PPP.Mul(&P, &PP) + Q.Mul(&U1, &PP) + V.Mul(&S1, &PPP) + + p.X.Square(&R). + Sub(&p.X, &PPP). + Sub(&p.X, &Q). + Sub(&p.X, &Q) + p.Y.Sub(&Q, &p.X). + Mul(&p.Y, &R). + Sub(&p.Y, &V) + p.ZZ.Mul(&p.ZZ, &q.ZZ). + Mul(&p.ZZ, &PP) + p.ZZZ.Mul(&p.ZZZ, &q.ZZZ). + Mul(&p.ZZZ, &PPP) + + return p +} + +// double point in Jacobian extended coordinates +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +func (p *g1JacExtended) double(q *g1JacExtended) *g1JacExtended { + var U, V, W, S, XX, M fp.Element + + U.Double(&q.Y) + V.Square(&U) + W.Mul(&U, &V) + S.Mul(&q.X, &V) + XX.Square(&q.X) + M.Double(&XX). + Add(&M, &XX) // -> + a, but a=0 here + U.Mul(&W, &q.Y) + + p.X.Square(&M). + Sub(&p.X, &S). + Sub(&p.X, &S) + p.Y.Sub(&S, &p.X). + Mul(&p.Y, &M). + Sub(&p.Y, &U) + p.ZZ.Mul(&V, &q.ZZ) + p.ZZZ.Mul(&W, &q.ZZZ) + + return p +} + +// subMixed same as addMixed, but will negate a.Y +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-madd-2008-s +func (p *g1JacExtended) subMixed(a *G1Affine) *g1JacExtended { + + //if a is infinity return p + if a.IsInfinity() { + return p + } + // p is infinity, return a + if p.ZZ.IsZero() { + p.X = a.X + p.Y.Neg(&a.Y) + p.ZZ.SetOne() + p.ZZZ.SetOne() + return p + } + + var P, R fp.Element + + // p2: a, p1: p + P.Mul(&a.X, &p.ZZ) + P.Sub(&P, &p.X) + + R.Mul(&a.Y, &p.ZZZ) + R.Neg(&R) + R.Sub(&R, &p.Y) + + if P.IsZero() { + if R.IsZero() { + return p.doubleNegMixed(a) + + } + p.ZZ = fp.Element{} + p.ZZZ = fp.Element{} + return p + } + + var PP, PPP, Q, Q2, RR, X3, Y3 fp.Element + + PP.Square(&P) + PPP.Mul(&P, &PP) + Q.Mul(&p.X, &PP) + RR.Square(&R) + X3.Sub(&RR, &PPP) + Q2.Double(&Q) + p.X.Sub(&X3, &Q2) + Y3.Sub(&Q, &p.X).Mul(&Y3, &R) + R.Mul(&p.Y, &PPP) + p.Y.Sub(&Y3, &R) + p.ZZ.Mul(&p.ZZ, &PP) + p.ZZZ.Mul(&p.ZZZ, &PPP) + + return p + +} + +// addMixed +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#addition-madd-2008-s +func (p *g1JacExtended) addMixed(a *G1Affine) *g1JacExtended { + + //if a is infinity return p + if a.IsInfinity() { + return p + } + // p is infinity, return a + if p.ZZ.IsZero() { + p.X = a.X + p.Y = a.Y + p.ZZ.SetOne() + p.ZZZ.SetOne() + return p + } + + var P, R fp.Element + + // p2: a, p1: p + P.Mul(&a.X, &p.ZZ) + P.Sub(&P, &p.X) + + R.Mul(&a.Y, &p.ZZZ) + R.Sub(&R, &p.Y) + + if P.IsZero() { + if R.IsZero() { + return p.doubleMixed(a) + + } + p.ZZ = fp.Element{} + p.ZZZ = fp.Element{} + return p + } + + var PP, PPP, Q, Q2, RR, X3, Y3 fp.Element + + PP.Square(&P) + PPP.Mul(&P, &PP) + Q.Mul(&p.X, &PP) + RR.Square(&R) + X3.Sub(&RR, &PPP) + Q2.Double(&Q) + p.X.Sub(&X3, &Q2) + Y3.Sub(&Q, &p.X).Mul(&Y3, &R) + R.Mul(&p.Y, &PPP) + p.Y.Sub(&Y3, &R) + p.ZZ.Mul(&p.ZZ, &PP) + p.ZZZ.Mul(&p.ZZZ, &PPP) + + return p + +} + +// doubleNegMixed same as double, but will negate q.Y +func (p *g1JacExtended) doubleNegMixed(q *G1Affine) *g1JacExtended { + + var U, V, W, S, XX, M, S2, L fp.Element + + U.Double(&q.Y) + U.Neg(&U) + V.Square(&U) + W.Mul(&U, &V) + S.Mul(&q.X, &V) + XX.Square(&q.X) + M.Double(&XX). + Add(&M, &XX) // -> + a, but a=0 here + S2.Double(&S) + L.Mul(&W, &q.Y) + + p.X.Square(&M). + Sub(&p.X, &S2) + p.Y.Sub(&S, &p.X). + Mul(&p.Y, &M). + Add(&p.Y, &L) + p.ZZ.Set(&V) + p.ZZZ.Set(&W) + + return p +} + +// doubleMixed point in Jacobian extended coordinates +// http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +func (p *g1JacExtended) doubleMixed(q *G1Affine) *g1JacExtended { + + var U, V, W, S, XX, M, S2, L fp.Element + + U.Double(&q.Y) + V.Square(&U) + W.Mul(&U, &V) + S.Mul(&q.X, &V) + XX.Square(&q.X) + M.Double(&XX). + Add(&M, &XX) // -> + a, but a=0 here + S2.Double(&S) + L.Mul(&W, &q.Y) + + p.X.Square(&M). + Sub(&p.X, &S2) + p.Y.Sub(&S, &p.X). + Mul(&p.Y, &M). + Sub(&p.Y, &L) + p.ZZ.Set(&V) + p.ZZZ.Set(&W) + + return p +} + +// BatchJacobianToAffineG1 converts points in Jacobian coordinates to Affine coordinates +// performing a single field inversion (Montgomery batch inversion trick). +func BatchJacobianToAffineG1(points []G1Jac) []G1Affine { + result := make([]G1Affine, len(points)) + zeroes := make([]bool, len(points)) + accumulator := fp.One() + + // batch invert all points[].Z coordinates with Montgomery batch inversion trick + // (stores points[].Z^-1 in result[i].X to avoid allocating a slice of fr.Elements) + for i := 0; i < len(points); i++ { + if points[i].Z.IsZero() { + zeroes[i] = true + continue + } + result[i].X = accumulator + accumulator.Mul(&accumulator, &points[i].Z) + } + + var accInverse fp.Element + accInverse.Inverse(&accumulator) + + for i := len(points) - 1; i >= 0; i-- { + if zeroes[i] { + // do nothing, (X=0, Y=0) is infinity point in affine + continue + } + result[i].X.Mul(&result[i].X, &accInverse) + accInverse.Mul(&accInverse, &points[i].Z) + } + + // batch convert to affine. + parallel.Execute(len(points), func(start, end int) { + for i := start; i < end; i++ { + if zeroes[i] { + // do nothing, (X=0, Y=0) is infinity point in affine + continue + } + var a, b fp.Element + a = result[i].X + b.Square(&a) + result[i].X.Mul(&points[i].X, &b) + result[i].Y.Mul(&points[i].Y, &b). + Mul(&result[i].Y, &a) + } + }) + + return result +} diff --git a/ecc/secp256k1/g1_test.go b/ecc/secp256k1/g1_test.go new file mode 100644 index 000000000..0de1717a8 --- /dev/null +++ b/ecc/secp256k1/g1_test.go @@ -0,0 +1,605 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package secp256k1 + +import ( + "math/big" + "math/rand" + "testing" + + "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" +) + +const ( + nbFuzzShort = 10 + nbFuzz = 100 +) + +func TestG1AffineEndomorphism(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[SECP256K1] check that phi(P) = lambdaGLV * P", prop.ForAll( + func(a fp.Element) bool { + var p, res1, res2 G1Jac + g := MapToG1(a) + p.FromAffine(&g) + res1.phi(&p) + res2.mulWindowed(&p, &lambdaGLV) + + return p.IsInSubGroup() && res1.Equal(&res2) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] check that phi^2(P) + phi(P) + P = 0", prop.ForAll( + func(a fp.Element) bool { + var p, res, tmp G1Jac + g := MapToG1(a) + p.FromAffine(&g) + tmp.phi(&p) + res.phi(&tmp). + AddAssign(&tmp). + AddAssign(&p) + + return res.Z.IsZero() + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestG1AffineIsOnCurve(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[SECP256K1] g1Gen (affine) should be on the curve", prop.ForAll( + func(a fp.Element) bool { + var op1, op2 G1Affine + op1.FromJacobian(&g1Gen) + op2.Set(&op1) + op2.Y.Mul(&op2.Y, &a) + return op1.IsOnCurve() && !op2.IsOnCurve() + }, + GenFp(), + )) + + properties.Property("[SECP256K1] g1Gen (Jacobian) should be on the curve", prop.ForAll( + func(a fp.Element) bool { + var op1, op2, op3 G1Jac + op1.Set(&g1Gen) + op3.Set(&g1Gen) + + op2 = fuzzG1Jac(&g1Gen, a) + op3.Y.Mul(&op3.Y, &a) + return op1.IsOnCurve() && op2.IsOnCurve() && !op3.IsOnCurve() + }, + GenFp(), + )) + + properties.Property("[SECP256K1] IsInSubGroup and MulBy subgroup order should be the same", prop.ForAll( + func(a fp.Element) bool { + var op1, op2 G1Jac + op1 = fuzzG1Jac(&g1Gen, a) + _r := fr.Modulus() + op2.ScalarMultiplication(&op1, _r) + return op1.IsInSubGroup() && op2.Z.IsZero() + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestG1AffineConversions(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[SECP256K1] Affine representation should be independent of the Jacobian representative", prop.ForAll( + func(a fp.Element) bool { + g := fuzzG1Jac(&g1Gen, a) + var op1 G1Affine + op1.FromJacobian(&g) + return op1.X.Equal(&g1Gen.X) && op1.Y.Equal(&g1Gen.Y) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] Affine representation should be independent of a Extended Jacobian representative", prop.ForAll( + func(a fp.Element) bool { + var g g1JacExtended + g.X.Set(&g1Gen.X) + g.Y.Set(&g1Gen.Y) + g.ZZ.Set(&g1Gen.Z) + g.ZZZ.Set(&g1Gen.Z) + gfuzz := fuzzg1JacExtended(&g, a) + + var op1 G1Affine + op1.fromJacExtended(&gfuzz) + return op1.X.Equal(&g1Gen.X) && op1.Y.Equal(&g1Gen.Y) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] Jacobian representation should be the same as the affine representative", prop.ForAll( + func(a fp.Element) bool { + var g G1Jac + var op1 G1Affine + op1.X.Set(&g1Gen.X) + op1.Y.Set(&g1Gen.Y) + + var one fp.Element + one.SetOne() + + g.FromAffine(&op1) + + return g.X.Equal(&g1Gen.X) && g.Y.Equal(&g1Gen.Y) && g.Z.Equal(&one) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] Converting affine symbol for infinity to Jacobian should output correct infinity in Jacobian", prop.ForAll( + func() bool { + var g G1Affine + g.X.SetZero() + g.Y.SetZero() + var op1 G1Jac + op1.FromAffine(&g) + var one, zero fp.Element + one.SetOne() + return op1.X.Equal(&one) && op1.Y.Equal(&one) && op1.Z.Equal(&zero) + }, + )) + + properties.Property("[SECP256K1] Converting infinity in extended Jacobian to affine should output infinity symbol in Affine", prop.ForAll( + func() bool { + var g G1Affine + var op1 g1JacExtended + var zero fp.Element + op1.X.Set(&g1Gen.X) + op1.Y.Set(&g1Gen.Y) + g.fromJacExtended(&op1) + return g.X.Equal(&zero) && g.Y.Equal(&zero) + }, + )) + + properties.Property("[SECP256K1] Converting infinity in extended Jacobian to Jacobian should output infinity in Jacobian", prop.ForAll( + func() bool { + var g G1Jac + var op1 g1JacExtended + var zero, one fp.Element + one.SetOne() + op1.X.Set(&g1Gen.X) + op1.Y.Set(&g1Gen.Y) + g.fromJacExtended(&op1) + return g.X.Equal(&one) && g.Y.Equal(&one) && g.Z.Equal(&zero) + }, + )) + + properties.Property("[SECP256K1] [Jacobian] Two representatives of the same class should be equal", prop.ForAll( + func(a, b fp.Element) bool { + op1 := fuzzG1Jac(&g1Gen, a) + op2 := fuzzG1Jac(&g1Gen, b) + return op1.Equal(&op2) + }, + GenFp(), + GenFp(), + )) + properties.Property("[SECP256K1] BatchJacobianToAffineG1 and FromJacobian should output the same result", prop.ForAll( + func(a, b fp.Element) bool { + g1 := fuzzG1Jac(&g1Gen, a) + g2 := fuzzG1Jac(&g1Gen, b) + var op1, op2 G1Affine + op1.FromJacobian(&g1) + op2.FromJacobian(&g2) + baseTableAff := BatchJacobianToAffineG1([]G1Jac{g1, g2}) + return op1.Equal(&baseTableAff[0]) && op2.Equal(&baseTableAff[1]) + }, + GenFp(), + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestG1AffineOps(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + parameters.MinSuccessfulTests = 10 + + properties := gopter.NewProperties(parameters) + + genScalar := GenFr() + + properties.Property("[SECP256K1] [Jacobian] Add should call double when having adding the same point", prop.ForAll( + func(a, b fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + fop2 := fuzzG1Jac(&g1Gen, b) + var op1, op2 G1Jac + op1.Set(&fop1).AddAssign(&fop2) + op2.Double(&fop2) + return op1.Equal(&op2) + }, + GenFp(), + GenFp(), + )) + + properties.Property("[SECP256K1] [Jacobian] Adding the opposite of a point to itself should output inf", prop.ForAll( + func(a, b fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + fop2 := fuzzG1Jac(&g1Gen, b) + fop2.Neg(&fop2) + fop1.AddAssign(&fop2) + return fop1.Equal(&g1Infinity) + }, + GenFp(), + GenFp(), + )) + + properties.Property("[SECP256K1] [Jacobian] Adding the inf to a point should not modify the point", prop.ForAll( + func(a fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + fop1.AddAssign(&g1Infinity) + var op2 G1Jac + op2.Set(&g1Infinity) + op2.AddAssign(&g1Gen) + return fop1.Equal(&g1Gen) && op2.Equal(&g1Gen) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] [Jacobian Extended] addMixed (-G) should equal subMixed(G)", prop.ForAll( + func(a fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + var p1, p1Neg G1Affine + p1.FromJacobian(&fop1) + p1Neg = p1 + p1Neg.Y.Neg(&p1Neg.Y) + var o1, o2 g1JacExtended + o1.addMixed(&p1Neg) + o2.subMixed(&p1) + + return o1.X.Equal(&o2.X) && + o1.Y.Equal(&o2.Y) && + o1.ZZ.Equal(&o2.ZZ) && + o1.ZZZ.Equal(&o2.ZZZ) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] [Jacobian Extended] doubleMixed (-G) should equal doubleNegMixed(G)", prop.ForAll( + func(a fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + var p1, p1Neg G1Affine + p1.FromJacobian(&fop1) + p1Neg = p1 + p1Neg.Y.Neg(&p1Neg.Y) + var o1, o2 g1JacExtended + o1.doubleMixed(&p1Neg) + o2.doubleNegMixed(&p1) + + return o1.X.Equal(&o2.X) && + o1.Y.Equal(&o2.Y) && + o1.ZZ.Equal(&o2.ZZ) && + o1.ZZZ.Equal(&o2.ZZZ) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] [Jacobian] Addmix the negation to itself should output 0", prop.ForAll( + func(a fp.Element) bool { + fop1 := fuzzG1Jac(&g1Gen, a) + fop1.Neg(&fop1) + var op2 G1Affine + op2.FromJacobian(&g1Gen) + fop1.AddMixed(&op2) + return fop1.Equal(&g1Infinity) + }, + GenFp(), + )) + + properties.Property("[SECP256K1] scalar multiplication (double and add) should depend only on the scalar mod r", prop.ForAll( + func(s fr.Element) bool { + + r := fr.Modulus() + var g G1Jac + g.mulGLV(&g1Gen, r) + + var scalar, blindedScalar, rminusone big.Int + var op1, op2, op3, gneg G1Jac + rminusone.SetUint64(1).Sub(r, &rminusone) + op3.mulWindowed(&g1Gen, &rminusone) + gneg.Neg(&g1Gen) + s.ToBigIntRegular(&scalar) + blindedScalar.Mul(&scalar, r).Add(&blindedScalar, &scalar) + op1.mulWindowed(&g1Gen, &scalar) + op2.mulWindowed(&g1Gen, &blindedScalar) + + return op1.Equal(&op2) && g.Equal(&g1Infinity) && !op1.Equal(&g1Infinity) && gneg.Equal(&op3) + + }, + genScalar, + )) + + properties.Property("[SECP256K1] scalar multiplication (GLV) should depend only on the scalar mod r", prop.ForAll( + func(s fr.Element) bool { + + r := fr.Modulus() + var g G1Jac + g.mulGLV(&g1Gen, r) + + var scalar, blindedScalar, rminusone big.Int + var op1, op2, op3, gneg G1Jac + rminusone.SetUint64(1).Sub(r, &rminusone) + op3.ScalarMultiplication(&g1Gen, &rminusone) + gneg.Neg(&g1Gen) + s.ToBigIntRegular(&scalar) + blindedScalar.Mul(&scalar, r).Add(&blindedScalar, &scalar) + op1.ScalarMultiplication(&g1Gen, &scalar) + op2.ScalarMultiplication(&g1Gen, &blindedScalar) + + return op1.Equal(&op2) && g.Equal(&g1Infinity) && !op1.Equal(&g1Infinity) && gneg.Equal(&op3) + + }, + genScalar, + )) + + properties.Property("[SECP256K1] GLV and Double and Add should output the same result", prop.ForAll( + func(s fr.Element) bool { + + var r big.Int + var op1, op2 G1Jac + s.ToBigIntRegular(&r) + op1.mulWindowed(&g1Gen, &r) + op2.mulGLV(&g1Gen, &r) + return op1.Equal(&op2) && !op1.Equal(&g1Infinity) + + }, + genScalar, + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +// ------------------------------------------------------------ +// benches + +func BenchmarkG1JacIsInSubGroup(b *testing.B) { + var a G1Jac + a.Set(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.IsInSubGroup() + } + +} + +func BenchmarkG1JacScalarMultiplication(b *testing.B) { + + var scalar big.Int + r := fr.Modulus() + scalar.SetString("5243587517512619047944770508185965837690552500527637822603658699938581184513", 10) + scalar.Add(&scalar, r) + + var doubleAndAdd G1Jac + + b.Run("double and add", func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + doubleAndAdd.mulWindowed(&g1Gen, &scalar) + } + }) + + var glv G1Jac + b.Run("GLV", func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + glv.mulGLV(&g1Gen, &scalar) + } + }) + +} + +func BenchmarkG1JacAdd(b *testing.B) { + var a G1Jac + a.Double(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.AddAssign(&g1Gen) + } +} + +func BenchmarkG1JacAddMixed(b *testing.B) { + var a G1Jac + a.Double(&g1Gen) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.AddMixed(&c) + } + +} + +func BenchmarkG1JacDouble(b *testing.B) { + var a G1Jac + a.Set(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.DoubleAssign() + } + +} + +func BenchmarkG1JacExtAddMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.addMixed(&c) + } +} + +func BenchmarkG1JacExtSubMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.subMixed(&c) + } +} + +func BenchmarkG1JacExtDoubleMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.doubleMixed(&c) + } +} + +func BenchmarkG1JacExtDoubleNegMixed(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + var c G1Affine + c.FromJacobian(&g1Gen) + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.doubleNegMixed(&c) + } +} + +func BenchmarkG1JacExtAdd(b *testing.B) { + var a, c g1JacExtended + a.doubleMixed(&g1GenAff) + c.double(&a) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.add(&c) + } +} + +func BenchmarkG1JacExtDouble(b *testing.B) { + var a g1JacExtended + a.doubleMixed(&g1GenAff) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + a.double(&a) + } +} + +func fuzzG1Jac(p *G1Jac, f fp.Element) G1Jac { + var res G1Jac + res.X.Mul(&p.X, &f).Mul(&res.X, &f) + res.Y.Mul(&p.Y, &f).Mul(&res.Y, &f).Mul(&res.Y, &f) + res.Z.Mul(&p.Z, &f) + return res +} + +func fuzzg1JacExtended(p *g1JacExtended, f fp.Element) g1JacExtended { + var res g1JacExtended + var ff, fff fp.Element + ff.Square(&f) + fff.Mul(&ff, &f) + res.X.Mul(&p.X, &ff) + res.Y.Mul(&p.Y, &fff) + res.ZZ.Mul(&p.ZZ, &ff) + res.ZZZ.Mul(&p.ZZZ, &fff) + return res +} + +// define Gopters generators + +// GenFr generates an Fr element +func GenFr() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fr.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + + return gopter.NewGenResult(elmt, gopter.NoShrinker) + } +} + +// GenFp generates an Fp element +func GenFp() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var elmt fp.Element + + if _, err := elmt.SetRandom(); err != nil { + panic(err) + } + + return gopter.NewGenResult(elmt, gopter.NoShrinker) + } +} + +// GenBigInt generates a big.Int +func GenBigInt() gopter.Gen { + return func(genParams *gopter.GenParameters) *gopter.GenResult { + var s big.Int + var b [fp.Bytes]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + s.SetBytes(b[:]) + genResult := gopter.NewGenResult(s, gopter.NoShrinker) + return genResult + } +} diff --git a/ecc/secp256k1/hash_to_g1.go b/ecc/secp256k1/hash_to_g1.go new file mode 100644 index 000000000..8f87a6634 --- /dev/null +++ b/ecc/secp256k1/hash_to_g1.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package secp256k1 + +import ( + "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" +) + +// mapToCurve1 implements the Shallue and van de Woestijne method, applicable to any elliptic curve in Weierstrass form +// No cofactor clearing or isogeny +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#straightline-svdw +func mapToCurve1(u *fp.Element) G1Affine { + var tv1, tv2, tv3, tv4 fp.Element + var x1, x2, x3, gx1, gx2, gx, x, y fp.Element + var one fp.Element + var gx1NotSquare, gx1SquareOrGx2Not int + + //constants + //c1 = g(Z) + //c2 = -Z / 2 + //c3 = sqrt(-g(Z) * (3 * Z² + 4 * A)) # sgn0(c3) MUST equal 0 + //c4 = -4 * g(Z) / (3 * Z² + 4 * A) + + Z := fp.Element{4294968273} + c1 := fp.Element{34359746184} + c2 := fp.Element{18446744069414583343, 18446744073709551615, 18446744073709551615, 9223372036854775807} + c3 := fp.Element{17776356894683298668, 8246953344079101458, 5263337950127723085, 5166576015932556431} + c4 := fp.Element{18446744023601588431, 18446744073709551615, 18446744073709551615, 18446744073709551615} + + one.SetOne() + + tv1.Square(u) // 1. tv1 = u² + tv1.Mul(&tv1, &c1) // 2. tv1 = tv1 * c1 + tv2.Add(&one, &tv1) // 3. tv2 = 1 + tv1 + tv1.Sub(&one, &tv1) // 4. tv1 = 1 - tv1 + tv3.Mul(&tv1, &tv2) // 5. tv3 = tv1 * tv2 + + tv3.Inverse(&tv3) // 6. tv3 = inv0(tv3) + tv4.Mul(u, &tv1) // 7. tv4 = u * tv1 + tv4.Mul(&tv4, &tv3) // 8. tv4 = tv4 * tv3 + tv4.Mul(&tv4, &c3) // 9. tv4 = tv4 * c3 + x1.Sub(&c2, &tv4) // 10. x1 = c2 - tv4 + + gx1.Square(&x1) // 11. gx1 = x1² + //12. gx1 = gx1 + A All curves in gnark-crypto have A=0 (j-invariant=0). It is crucial to include this step if the curve has nonzero A coefficient. + gx1.Mul(&gx1, &x1) // 13. gx1 = gx1 * x1 + gx1.Add(&gx1, &bCurveCoeff) // 14. gx1 = gx1 + B + gx1NotSquare = gx1.Legendre() >> 1 // 15. e1 = is_square(gx1) + // gx1NotSquare = 0 if gx1 is a square, -1 otherwise + + x2.Add(&c2, &tv4) // 16. x2 = c2 + tv4 + gx2.Square(&x2) // 17. gx2 = x2² + // 18. gx2 = gx2 + A See line 12 + gx2.Mul(&gx2, &x2) // 19. gx2 = gx2 * x2 + gx2.Add(&gx2, &bCurveCoeff) // 20. gx2 = gx2 + B + + { + gx2NotSquare := gx2.Legendre() >> 1 // gx2Square = 0 if gx2 is a square, -1 otherwise + gx1SquareOrGx2Not = gx2NotSquare | ^gx1NotSquare // 21. e2 = is_square(gx2) AND NOT e1 # Avoid short-circuit logic ops + } + + x3.Square(&tv2) // 22. x3 = tv2² + x3.Mul(&x3, &tv3) // 23. x3 = x3 * tv3 + x3.Square(&x3) // 24. x3 = x3² + x3.Mul(&x3, &c4) // 25. x3 = x3 * c4 + + x3.Add(&x3, &Z) // 26. x3 = x3 + Z + x.Select(gx1NotSquare, &x1, &x3) // 27. x = CMOV(x3, x1, e1) # x = x1 if gx1 is square, else x = x3 + // Select x1 iff gx1 is square iff gx1NotSquare = 0 + x.Select(gx1SquareOrGx2Not, &x2, &x) // 28. x = CMOV(x, x2, e2) # x = x2 if gx2 is square and gx1 is not + // Select x2 iff gx2 is square and gx1 is not, iff gx1SquareOrGx2Not = 0 + gx.Square(&x) // 29. gx = x² + // 30. gx = gx + A + + gx.Mul(&gx, &x) // 31. gx = gx * x + gx.Add(&gx, &bCurveCoeff) // 32. gx = gx + B + + y.Sqrt(&gx) // 33. y = sqrt(gx) + signsNotEqual := g1Sgn0(u) ^ g1Sgn0(&y) // 34. e3 = sgn0(u) == sgn0(y) + + tv1.Neg(&y) + y.Select(int(signsNotEqual), &y, &tv1) // 35. y = CMOV(-y, y, e3) # Select correct sign of y + return G1Affine{x, y} +} + +// g1Sgn0 is an algebraic substitute for the notion of sign in ordered fields +// Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function +// The sign of an element is not obviously related to that of its Montgomery form +func g1Sgn0(z *fp.Element) uint64 { + + nonMont := z.Bits() + + // m == 1 + return nonMont[0] % 2 + +} + +// MapToG1 invokes the SVDW map, and guarantees that the result is in g1 +func MapToG1(u fp.Element) G1Affine { + res := mapToCurve1(&u) + return res +} + +// EncodeToG1 hashes a message to a point on the G1 curve using the SVDW map. +// It is faster than HashToG1, but the result is not uniformly distributed. Unsuitable as a random oracle. +// dst stands for "domain separation tag", a string unique to the construction using the hash function +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +func EncodeToG1(msg, dst []byte) (G1Affine, error) { + + var res G1Affine + u, err := fp.Hash(msg, dst, 1) + if err != nil { + return res, err + } + + res = mapToCurve1(&u[0]) + + return res, nil +} + +// HashToG1 hashes a message to a point on the G1 curve using the SVDW map. +// Slower than EncodeToG1, but usable as a random oracle. +// dst stands for "domain separation tag", a string unique to the construction using the hash function +// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap +func HashToG1(msg, dst []byte) (G1Affine, error) { + u, err := fp.Hash(msg, dst, 2*1) + if err != nil { + return G1Affine{}, err + } + + Q0 := mapToCurve1(&u[0]) + Q1 := mapToCurve1(&u[1]) + + var _Q0, _Q1 G1Jac + _Q0.FromAffine(&Q0) + _Q1.FromAffine(&Q1).AddAssign(&_Q0) + + Q1.FromJacobian(&_Q1) + return Q1, nil +} + +func g1NotZero(x *fp.Element) uint64 { + + return x[0] | x[1] | x[2] | x[3] + +} diff --git a/ecc/secp256k1/hash_to_g1_test.go b/ecc/secp256k1/hash_to_g1_test.go new file mode 100644 index 000000000..825264e00 --- /dev/null +++ b/ecc/secp256k1/hash_to_g1_test.go @@ -0,0 +1,234 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package secp256k1 + +import ( + "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + "github.com/leanovate/gopter" + "github.com/leanovate/gopter/prop" + "math/rand" + "testing" +) + +func TestHashToFpG1(t *testing.T) { + for _, c := range encodeToG1Vector.cases { + elems, err := fp.Hash([]byte(c.msg), encodeToG1Vector.dst, 1) + if err != nil { + t.Error(err) + } + g1TestMatchCoord(t, "u", c.msg, c.u, g1CoordAt(elems, 0)) + } + + for _, c := range hashToG1Vector.cases { + elems, err := fp.Hash([]byte(c.msg), hashToG1Vector.dst, 2*1) + if err != nil { + t.Error(err) + } + g1TestMatchCoord(t, "u0", c.msg, c.u0, g1CoordAt(elems, 0)) + g1TestMatchCoord(t, "u1", c.msg, c.u1, g1CoordAt(elems, 1)) + } +} + +func TestMapToCurve1(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[G1] mapping output must be on curve", prop.ForAll( + func(a fp.Element) bool { + + g := mapToCurve1(&a) + + if !g.IsOnCurve() { + t.Log("SVDW output not on curve") + return false + } + + return true + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) + + for _, c := range encodeToG1Vector.cases { + var u fp.Element + g1CoordSetString(&u, c.u) + q := mapToCurve1(&u) + g1TestMatchPoint(t, "Q", c.msg, c.Q, &q) + } + + for _, c := range hashToG1Vector.cases { + var u fp.Element + g1CoordSetString(&u, c.u0) + q := mapToCurve1(&u) + g1TestMatchPoint(t, "Q0", c.msg, c.Q0, &q) + + g1CoordSetString(&u, c.u1) + q = mapToCurve1(&u) + g1TestMatchPoint(t, "Q1", c.msg, c.Q1, &q) + } +} + +func TestMapToG1(t *testing.T) { + t.Parallel() + parameters := gopter.DefaultTestParameters() + if testing.Short() { + parameters.MinSuccessfulTests = nbFuzzShort + } else { + parameters.MinSuccessfulTests = nbFuzz + } + + properties := gopter.NewProperties(parameters) + + properties.Property("[G1] mapping to curve should output point on the curve", prop.ForAll( + func(a fp.Element) bool { + g := MapToG1(a) + return g.IsInSubGroup() + }, + GenFp(), + )) + + properties.Property("[G1] mapping to curve should be deterministic", prop.ForAll( + func(a fp.Element) bool { + g1 := MapToG1(a) + g2 := MapToG1(a) + return g1.Equal(&g2) + }, + GenFp(), + )) + + properties.TestingRun(t, gopter.ConsoleReporter(false)) +} + +func TestEncodeToG1(t *testing.T) { + t.Parallel() + for _, c := range encodeToG1Vector.cases { + p, err := EncodeToG1([]byte(c.msg), encodeToG1Vector.dst) + if err != nil { + t.Fatal(err) + } + g1TestMatchPoint(t, "P", c.msg, c.P, &p) + } +} + +func TestHashToG1(t *testing.T) { + t.Parallel() + for _, c := range hashToG1Vector.cases { + p, err := HashToG1([]byte(c.msg), hashToG1Vector.dst) + if err != nil { + t.Fatal(err) + } + g1TestMatchPoint(t, "P", c.msg, c.P, &p) + } +} + +func BenchmarkEncodeToG1(b *testing.B) { + const size = 54 + bytes := make([]byte, size) + dst := encodeToG1Vector.dst + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + bytes[rand.Int()%size] = byte(rand.Int()) + + if _, err := EncodeToG1(bytes, dst); err != nil { + b.Fail() + } + } +} + +func BenchmarkHashToG1(b *testing.B) { + const size = 54 + bytes := make([]byte, size) + dst := hashToG1Vector.dst + b.ResetTimer() + + for i := 0; i < b.N; i++ { + + bytes[rand.Int()%size] = byte(rand.Int()) + + if _, err := HashToG1(bytes, dst); err != nil { + b.Fail() + } + } +} + +// Only works on simple extensions (two-story towers) +func g1CoordSetString(z *fp.Element, s string) { + z.SetString(s) +} + +func g1CoordAt(slice []fp.Element, i int) fp.Element { + return slice[i] +} + +func g1TestMatchCoord(t *testing.T, coordName string, msg string, expectedStr string, seen fp.Element) { + var expected fp.Element + + g1CoordSetString(&expected, expectedStr) + + if !expected.Equal(&seen) { + t.Errorf("mismatch on \"%s\", %s:\n\texpected %s\n\tsaw %s", msg, coordName, expected.String(), &seen) + } +} + +func g1TestMatchPoint(t *testing.T, pointName string, msg string, expected point, seen *G1Affine) { + g1TestMatchCoord(t, pointName+".x", msg, expected.x, seen.X) + g1TestMatchCoord(t, pointName+".y", msg, expected.y, seen.Y) +} + +type hashTestVector struct { + dst []byte + cases []hashTestCase +} + +type encodeTestVector struct { + dst []byte + cases []encodeTestCase +} + +type point struct { + x string + y string +} + +type encodeTestCase struct { + msg string + P point //pY a coordinate of P, the final output + u string //u hashed onto the field + Q point //Q map to curve output +} + +type hashTestCase struct { + msg string + P point //pY a coordinate of P, the final output + u0 string //u0 hashed onto the field + u1 string //u1 extra hashed onto the field + Q0 point //Q0 map to curve output + Q1 point //Q1 extra map to curve output +} + +var encodeToG1Vector encodeTestVector +var hashToG1Vector hashTestVector diff --git a/ecc/secp256k1/secp256k1.go b/ecc/secp256k1/secp256k1.go new file mode 100644 index 000000000..feed74366 --- /dev/null +++ b/ecc/secp256k1/secp256k1.go @@ -0,0 +1,91 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package secp256k1 efficient elliptic curve implementation for secp256k1. This curve is defined in Standards for Efficient Cryptography (SEC) (Certicom Research, http://www.secg.org/sec2-v2.pdf) and appears in the Bitcoin and Ethereum ECDSA signatures. +// +// secp256k1: A j=0 curve with +// +// 𝔽r: r=115792089237316195423570985008687907852837564279074904382605163141518161494337 +// 𝔽p: p=115792089237316195423570985008687907853269984665640564039457584007908834671663 (2^256 - 2^32 - 977) +// (E/𝔽p): Y²=X³+7 +// +// Security: estimated 128-bit level using Pollard's \rho attack +// (r is 256 bits) +// +// # Warning +// +// This code has been partially audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. +package secp256k1 + +import ( + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fp" + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" +) + +// ID secp256k1 ID +const ID = ecc.SECP256K1 + +// bCurveCoeff b coeff of the curve Y²=X³+b +var bCurveCoeff fp.Element + +// generator of the r-torsion group +var g1Gen G1Jac + +var g1GenAff G1Affine + +// point at infinity +var g1Infinity G1Jac + +// Parameters useful for the GLV scalar multiplication. The third roots define the +// endomorphisms ϕ₁ for . lambda is such that lies above +// in the ring Z[ϕ]. More concretely it's the associated eigenvalue +// of ϕ₁ restricted to +// see https://www.cosic.esat.kuleuven.be/nessie/reports/phase2/GLV.pdf +var thirdRootOneG1 fp.Element +var lambdaGLV big.Int + +// glvBasis stores R-linearly independent vectors (a,b), (c,d) +// in ker((u,v) → u+vλ[r]), and their determinant +var glvBasis ecc.Lattice + +func init() { + + bCurveCoeff.SetUint64(7) + + g1Gen.X.SetString("55066263022277343669578718895168534326250603453777594175500187360389116729240") + g1Gen.Y.SetString("32670510020758816978083085130507043184471273380659243275938904335757337482424") + g1Gen.Z.SetOne() + + g1GenAff.FromJacobian(&g1Gen) + + // (X,Y,Z) = (1,1,0) + g1Infinity.X.SetOne() + g1Infinity.Y.SetOne() + + thirdRootOneG1.SetString("55594575648329892869085402983802832744385952214688224221778511981742606582254") // 2^((p-1)/3) + lambdaGLV.SetString("37718080363155996902926221483475020450927657555482586988616620542887997980018", 10) // 3^((r-1)/3) + _r := fr.Modulus() + ecc.PrecomputeLattice(_r, &lambdaGLV, &glvBasis) + +} + +// Generators return the generators of the r-torsion group, resp. in ker(pi-id), ker(Tr) +func Generators() (g1Jac G1Jac, g1Aff G1Affine) { + g1Aff = g1GenAff + g1Jac = g1Gen + return +} diff --git a/ecc/utils.go b/ecc/utils.go index f02b3c3d7..78126fc1c 100644 --- a/ecc/utils.go +++ b/ecc/utils.go @@ -1,8 +1,6 @@ package ecc import ( - "crypto/sha256" - "errors" "math/big" "math/bits" ) @@ -180,94 +178,6 @@ func getVector(l *Lattice, a, b *big.Int) [2]big.Int { return res } -func min(a, b int) int { - if a < b { - return a - } - return b -} - -// ExpandMsgXmd expands msg to a slice of lenInBytes bytes. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5 -// https://tools.ietf.org/html/rfc8017#section-4.1 (I2OSP/O2ISP) -func ExpandMsgXmd(msg, dst []byte, lenInBytes int) ([]byte, error) { - - h := sha256.New() - ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes) - if ell > 255 { - return nil, errors.New("invalid lenInBytes") - } - if len(dst) > 255 { - return nil, errors.New("invalid domain size (>255 bytes)") - } - sizeDomain := uint8(len(dst)) - - // Z_pad = I2OSP(0, r_in_bytes) - // l_i_b_str = I2OSP(len_in_bytes, 2) - // DST_prime = I2OSP(len(DST), 1) ∥ DST - // b₀ = H(Z_pad ∥ msg ∥ l_i_b_str ∥ I2OSP(0, 1) ∥ DST_prime) - h.Reset() - if _, err := h.Write(make([]byte, h.BlockSize())); err != nil { - return nil, err - } - if _, err := h.Write(msg); err != nil { - return nil, err - } - if _, err := h.Write([]byte{uint8(lenInBytes >> 8), uint8(lenInBytes), uint8(0)}); err != nil { - return nil, err - } - if _, err := h.Write(dst); err != nil { - return nil, err - } - if _, err := h.Write([]byte{sizeDomain}); err != nil { - return nil, err - } - b0 := h.Sum(nil) - - // b₁ = H(b₀ ∥ I2OSP(1, 1) ∥ DST_prime) - h.Reset() - if _, err := h.Write(b0); err != nil { - return nil, err - } - if _, err := h.Write([]byte{uint8(1)}); err != nil { - return nil, err - } - if _, err := h.Write(dst); err != nil { - return nil, err - } - if _, err := h.Write([]byte{sizeDomain}); err != nil { - return nil, err - } - b1 := h.Sum(nil) - - res := make([]byte, lenInBytes) - copy(res[:h.Size()], b1) - - for i := 2; i <= ell; i++ { - // b_i = H(strxor(b₀, b_(i - 1)) ∥ I2OSP(i, 1) ∥ DST_prime) - h.Reset() - strxor := make([]byte, h.Size()) - for j := 0; j < h.Size(); j++ { - strxor[j] = b0[j] ^ b1[j] - } - if _, err := h.Write(strxor); err != nil { - return nil, err - } - if _, err := h.Write([]byte{uint8(i)}); err != nil { - return nil, err - } - if _, err := h.Write(dst); err != nil { - return nil, err - } - if _, err := h.Write([]byte{sizeDomain}); err != nil { - return nil, err - } - b1 = h.Sum(nil) - copy(res[h.Size()*(i-1):min(h.Size()*i, len(res))], b1) - } - return res, nil -} - // NextPowerOfTwo returns the next power of 2 of n func NextPowerOfTwo(n uint64) uint64 { c := bits.OnesCount64(n) diff --git a/ecc/utils_test.go b/ecc/utils_test.go index e4a98b16d..3bd9e483a 100644 --- a/ecc/utils_test.go +++ b/ecc/utils_test.go @@ -1,8 +1,6 @@ package ecc import ( - "bytes" - "encoding/hex" "math/big" "testing" ) @@ -61,128 +59,3 @@ func BenchmarkSplitting256(b *testing.B) { } } - -type expandMsgXmdTestCase struct { - msg string - lenInBytes int - uniformBytesHex string -} - -//Test vectors from https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/14/ Page 148 Section K.1. -func TestExpandMsgXmd(t *testing.T) { - //name := "expand_message_xmd" - dst := "QUUX-V01-CS02-with-expander-SHA256-128" - //hash := "SHA256" - //k := 128 - - testCases := []expandMsgXmdTestCase{ - { - "", - 0x20, - "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", - }, - - { - "abc", - 0x20, - "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615", - }, - - { - "abcdef0123456789", - 0x20, - "eff31487c770a893cfb36f912fbfcbff40d5661771ca4b2cb4eafe524333f5c1", - }, - - { - "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", - 0x20, - "b23a1d2b4d97b2ef7785562a7e8bac7eed54ed6e97e29aa51bfe3f12ddad1ff9", - }, - - { - "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - 0x20, - "4623227bcc01293b8c130bf771da8c298dede7383243dc0993d2d94823958c4c", - }, - { - "", - 0x80, - "af84c27ccfd45d41914fdff5df25293e221afc53d8ad2ac06d5e3e29485dadbee0d121587713a3e0dd4d5e69e93eb7cd4f5df4cd103e188cf60cb02edc3edf18eda8576c412b18ffb658e3dd6ec849469b979d444cf7b26911a08e63cf31f9dcc541708d3491184472c2c29bb749d4286b004ceb5ee6b9a7fa5b646c993f0ced", - }, - { - "", - 0x20, - "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", - }, - { - "abc", - 0x80, - "abba86a6129e366fc877aab32fc4ffc70120d8996c88aee2fe4b32d6c7b6437a647e6c3163d40b76a73cf6a5674ef1d890f95b664ee0afa5359a5c4e07985635bbecbac65d747d3d2da7ec2b8221b17b0ca9dc8a1ac1c07ea6a1e60583e2cb00058e77b7b72a298425cd1b941ad4ec65e8afc50303a22c0f99b0509b4c895f40", - }, - { - "abcdef0123456789", - 0x80, - "ef904a29bffc4cf9ee82832451c946ac3c8f8058ae97d8d629831a74c6572bd9ebd0df635cd1f208e2038e760c4994984ce73f0d55ea9f22af83ba4734569d4bc95e18350f740c07eef653cbb9f87910d833751825f0ebefa1abe5420bb52be14cf489b37fe1a72f7de2d10be453b2c9d9eb20c7e3f6edc5a60629178d9478df", - }, - { - "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", - 0x80, - "80be107d0884f0d881bb460322f0443d38bd222db8bd0b0a5312a6fedb49c1bbd88fd75d8b9a09486c60123dfa1d73c1cc3169761b17476d3c6b7cbbd727acd0e2c942f4dd96ae3da5de368d26b32286e32de7e5a8cb2949f866a0b80c58116b29fa7fabb3ea7d520ee603e0c25bcaf0b9a5e92ec6a1fe4e0391d1cdbce8c68a", - }, - { - "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - 0x80, - "546aff5444b5b79aa6148bd81728704c32decb73a3ba76e9e75885cad9def1d06d6792f8a7d12794e90efed817d96920d728896a4510864370c207f99bd4a608ea121700ef01ed879745ee3e4ceef777eda6d9e5e38b90c86ea6fb0b36504ba4a45d22e86f6db5dd43d98a294bebb9125d5b794e9d2a81181066eb954966a487", - }, - //test cases not in the standard - { - "", - 0x30, - "3808e9bb0ade2df3aa6f1b459eb5058a78142f439213ddac0c97dcab92ae5a8408d86b32bbcc87de686182cbdf65901f", - }, - { - "abc", - 0x30, - "2b877f5f0dfd881405426c6b87b39205ef53a548b0e4d567fc007cb37c6fa1f3b19f42871efefca518ac950c27ac4e28", - }, - { - "abcdef0123456789", - 0x30, - "226da1780b06e59723714f80da9a63648aebcfc1f08e0db87b5b4d16b108da118214c1450b0e86f9cefeb44903fd3aba", - }, - { - "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", - 0x30, - "12b23ae2e888f442fd6d0d85d90a0d7ed5337d38113e89cdc7c22db91bd0abaec1023e9a8f0ef583a111104e2f8a0637", - }, - { - "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", - 0x30, - "1aaee90016547a85ab4dc55e4f78a364c2e239c0e58b05753453c63e6e818334005e90d9ce8f047bddab9fbb315f8722", - }, - } - - for _, testCase := range testCases { - uniformBytes, err := ExpandMsgXmd([]byte(testCase.msg), []byte(dst), testCase.lenInBytes) - if err != nil { - t.Fatal(err) - } - - var testCaseUniformBytes []byte - testCaseUniformBytes, err = hex.DecodeString(testCase.uniformBytesHex) - if err != nil { - t.Fatal(err) - } - - if len(uniformBytes) != testCase.lenInBytes { - t.Error("wrong length: expected", testCase.lenInBytes, "got", len(uniformBytes)) - } - - if !bytes.Equal(uniformBytes, testCaseUniformBytes) { - uniformBytesHex := make([]byte, len(uniformBytes)*2) - hex.Encode(uniformBytesHex, uniformBytes) - t.Errorf("expected \"%s\" got \"%s\"", testCase.uniformBytesHex, uniformBytesHex) - } - } -} diff --git a/fiat-shamir/settings.go b/fiat-shamir/settings.go new file mode 100644 index 000000000..3d4af160f --- /dev/null +++ b/fiat-shamir/settings.go @@ -0,0 +1,25 @@ +package fiatshamir + +import "hash" + +type Settings struct { + Transcript *Transcript + Prefix string + BaseChallenges [][]byte + Hash hash.Hash +} + +func WithTranscript(transcript *Transcript, prefix string, baseChallenges ...[]byte) Settings { + return Settings{ + Transcript: transcript, + Prefix: prefix, + BaseChallenges: baseChallenges, + } +} + +func WithHash(hash hash.Hash, baseChallenges ...[]byte) Settings { + return Settings{ + BaseChallenges: baseChallenges, + Hash: hash, + } +} diff --git a/fiat-shamir/transcript.go b/fiat-shamir/transcript.go index 070922ba7..d371556d0 100644 --- a/fiat-shamir/transcript.go +++ b/fiat-shamir/transcript.go @@ -21,7 +21,7 @@ import ( // errChallengeNotFound is returned when a wrong challenge name is provided. var ( - errChallengeNotFound = errors.New("challenge not recorded in the Transcript") + errChallengeNotFound = errors.New("challenge not recorded in the transcript") errChallengeAlreadyComputed = errors.New("challenge already computed, cannot be binded to other values") errPreviousChallengeNotComputed = errors.New("the previous challenge is needed and has not been computed") ) @@ -36,7 +36,7 @@ type Transcript struct { } type challenge struct { - position int // position of the challenge in the transcript. order matters. + position int // position of the challenge in the Transcript. order matters. bindings []byte // bindings stores the variables a challenge is binded to. value []byte // value stores the computed challenge isComputed bool diff --git a/internal/field/asm/amd64/asm_macros.go b/field/generator/asm/amd64/asm_macros.go similarity index 100% rename from internal/field/asm/amd64/asm_macros.go rename to field/generator/asm/amd64/asm_macros.go diff --git a/internal/field/asm/amd64/build.go b/field/generator/asm/amd64/build.go similarity index 93% rename from internal/field/asm/amd64/build.go rename to field/generator/asm/amd64/build.go index a52196c38..ed022ee18 100644 --- a/internal/field/asm/amd64/build.go +++ b/field/generator/asm/amd64/build.go @@ -23,17 +23,17 @@ import ( "github.com/consensys/bavard" "github.com/consensys/bavard/amd64" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field/generator/config" ) const SmallModulus = 6 -func NewFFAmd64(w io.Writer, F *field.FieldConfig) *FFAmd64 { +func NewFFAmd64(w io.Writer, F *config.FieldConfig) *FFAmd64 { return &FFAmd64{F, amd64.NewAmd64(w), 0, 0} } type FFAmd64 struct { - *field.FieldConfig + *config.FieldConfig *amd64.Amd64 nbElementsOnStack int maxOnStack int @@ -137,7 +137,7 @@ func (f *FFAmd64) qInv0() string { // Generate generates assembly code for the base field provided to goff // see internal/templates/ops* -func Generate(w io.Writer, F *field.FieldConfig) error { +func Generate(w io.Writer, F *config.FieldConfig) error { f := NewFFAmd64(w, F) f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) @@ -161,7 +161,7 @@ func Generate(w io.Writer, F *field.FieldConfig) error { return nil } -func GenerateMul(w io.Writer, F *field.FieldConfig) error { +func GenerateMul(w io.Writer, F *config.FieldConfig) error { f := NewFFAmd64(w, F) f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) @@ -179,7 +179,7 @@ func GenerateMul(w io.Writer, F *field.FieldConfig) error { return nil } -func GenerateMulADX(w io.Writer, F *field.FieldConfig) error { +func GenerateMulADX(w io.Writer, F *config.FieldConfig) error { f := NewFFAmd64(w, F) f.WriteLn(bavard.Apache2Header("ConsenSys Software Inc.", 2020)) diff --git a/internal/field/asm/amd64/element_butterfly.go b/field/generator/asm/amd64/element_butterfly.go similarity index 96% rename from internal/field/asm/amd64/element_butterfly.go rename to field/generator/asm/amd64/element_butterfly.go index c907d614a..87a2c5468 100644 --- a/internal/field/asm/amd64/element_butterfly.go +++ b/field/generator/asm/amd64/element_butterfly.go @@ -17,11 +17,12 @@ package amd64 // Butterfly sets // a = a + b // b = a - b -// func Butterfly(a, b *{{.ElementName}}) { -// t := *a -// a.Add(a, b) -// b.Sub(&t, b) -// } +// +// func Butterfly(a, b *{{.ElementName}}) { +// t := *a +// a.Add(a, b) +// b.Sub(&t, b) +// } func (f *FFAmd64) generateButterfly() { f.Comment("Butterfly(a, b *Element) sets a = a + b; b = a - b") diff --git a/internal/field/asm/amd64/element_frommont.go b/field/generator/asm/amd64/element_frommont.go similarity index 100% rename from internal/field/asm/amd64/element_frommont.go rename to field/generator/asm/amd64/element_frommont.go diff --git a/internal/field/asm/amd64/element_mul.go b/field/generator/asm/amd64/element_mul.go similarity index 100% rename from internal/field/asm/amd64/element_mul.go rename to field/generator/asm/amd64/element_mul.go diff --git a/internal/field/asm/amd64/element_mul_constants.go b/field/generator/asm/amd64/element_mul_constants.go similarity index 100% rename from internal/field/asm/amd64/element_mul_constants.go rename to field/generator/asm/amd64/element_mul_constants.go diff --git a/internal/field/asm/amd64/element_reduce.go b/field/generator/asm/amd64/element_reduce.go similarity index 100% rename from internal/field/asm/amd64/element_reduce.go rename to field/generator/asm/amd64/element_reduce.go diff --git a/internal/field/extension.go b/field/generator/config/extension.go similarity index 97% rename from internal/field/extension.go rename to field/generator/config/extension.go index b2adc1ff4..53221f89c 100644 --- a/internal/field/extension.go +++ b/field/generator/config/extension.go @@ -1,10 +1,10 @@ -package field +package config import "math/big" type Element []big.Int -//Extension is a simple radical extension, obtained by adjoining ⁿ√α to Fp +// Extension is a simple radical extension, obtained by adjoining ⁿ√α to Fp type Extension struct { Base *FieldConfig //Fp Size big.Int //q diff --git a/internal/field/field.go b/field/generator/config/field_config.go similarity index 92% rename from internal/field/field.go rename to field/generator/config/field_config.go index eafca7335..457a89d7d 100644 --- a/internal/field/field.go +++ b/field/generator/config/field_config.go @@ -12,19 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package field provides Golang code generation for efficient field arithmetic operations. -package field +// Package config provides Golang code generation for efficient field arithmetic operations. +package config import ( "errors" "fmt" - "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/field/internal/addchain" "math" "math/big" - "math/bits" "strconv" "strings" + + "github.com/consensys/bavard" + "github.com/consensys/gnark-crypto/field/generator/internal/addchain" ) var ( @@ -40,6 +40,7 @@ type FieldConfig struct { ModulusHex string NbWords int NbBits int + NbBytes int NbWordsLastIndex int NbWordsIndexesNoZero []int NbWordsIndexesFull []int @@ -95,6 +96,7 @@ func NewFieldConfig(packageName, elementName, modulus string, useAddChain bool) // pre compute field constants F.NbBits = bModulus.BitLen() F.NbWords = len(bModulus.Bits()) + F.NbBytes = F.NbWords * 8 // (F.NbBits + 7) / 8 F.NbWordsLastIndex = F.NbWords - 1 @@ -309,8 +311,8 @@ func extendedEuclideanAlgo(r, q, rInv, qInv *big.Int) { qInv.Neg(qInv) } -//StringToMont takes an element written in string form, and returns it in Montgomery form -//Useful for hard-coding in implementation field elements from standards documents +// StringToMont takes an element written in string form, and returns it in Montgomery form +// Useful for hard-coding in implementation field elements from standards documents func (f *FieldConfig) StringToMont(str string) big.Int { var i big.Int @@ -372,31 +374,6 @@ func (f *FieldConfig) halve(res *big.Int, x *big.Int) { res.Rsh(&z, 1) } -func BigIntMatchUint64Slice(aInt *big.Int, a []uint64) error { - - words := aInt.Bits() - - const steps = 64 / bits.UintSize - const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) - for i := 0; i < len(a)*steps; i++ { - - var wI big.Word - - if i < len(words) { - wI = words[i] - } - - aI := a[i/steps] >> ((i * bits.UintSize) % 64) - aI &= filter - - if uint64(wI) != aI { - return fmt.Errorf("bignum mismatch: disagreement on word %d: %x ≠ %x; %d ≠ %d", i, uint64(wI), aI, uint64(wI), aI) - } - } - - return nil -} - func (f *FieldConfig) Mul(z *big.Int, x *big.Int, y *big.Int) *FieldConfig { z.Mul(x, y).Mod(z, f.ModulusBig) return f @@ -415,7 +392,7 @@ func (f *FieldConfig) ToMontSlice(x []big.Int) []big.Int { return z } -//TODO: Spaghetti Alert: Okay to have codegen functions here? +// TODO: Spaghetti Alert: Okay to have codegen functions here? func CoordNameForExtensionDegree(degree uint8) string { switch degree { case 1: diff --git a/internal/field/field_test.go b/field/generator/config/field_test.go similarity index 97% rename from internal/field/field_test.go rename to field/generator/config/field_test.go index 72d486cd7..a2eba6d18 100644 --- a/internal/field/field_test.go +++ b/field/generator/config/field_test.go @@ -1,4 +1,4 @@ -package field +package config import ( "crypto/rand" @@ -7,6 +7,7 @@ import ( mrand "math/rand" "testing" + "github.com/consensys/gnark-crypto/field" "github.com/leanovate/gopter/gen" "github.com/leanovate/gopter" @@ -45,7 +46,7 @@ func TestIntToMont(t *testing.T) { i.Lsh(i, 64*uint(f.NbWords)) *i = f.ToMont(*i) - err := BigIntMatchUint64Slice(i, f.RSquare) + err := field.BigIntMatchUint64Slice(i, f.RSquare) return err == nil, err }, genF), ) @@ -73,7 +74,7 @@ func TestBigIntMatchUint64Slice(t *testing.T) { ints[j/8] ^= uint64(bytes[len(bytes)-1-j]) << (8 * (j % 8)) } - err := BigIntMatchUint64Slice(&i, ints) + err := field.BigIntMatchUint64Slice(&i, ints) return err == nil, err }, genF, genUint8SliceSlice(1))) diff --git a/internal/field/generator/generator.go b/field/generator/generator.go similarity index 79% rename from internal/field/generator/generator.go rename to field/generator/generator.go index 0a7ff5079..2af3c3fcd 100644 --- a/internal/field/generator/generator.go +++ b/field/generator/generator.go @@ -10,28 +10,26 @@ import ( "text/template" "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/field" - "github.com/consensys/gnark-crypto/internal/field/asm/amd64" - "github.com/consensys/gnark-crypto/internal/field/internal/addchain" - "github.com/consensys/gnark-crypto/internal/field/internal/templates/element" + "github.com/consensys/gnark-crypto/field/generator/asm/amd64" + "github.com/consensys/gnark-crypto/field/generator/config" + "github.com/consensys/gnark-crypto/field/generator/internal/addchain" + "github.com/consensys/gnark-crypto/field/generator/internal/templates/element" ) -// TODO @gbotrel → pattern for code generation is different than gnark-crypto/internal because a binary like goff can generate -// base field. in Go 1.16, can embed the template in the binary, and use same pattern than gnark-crypto/internal - // GenerateFF will generate go (and .s) files in outputDir for modulus (in base 10) // // Example usage // -// fp, _ = field.NewField("fp", "Element", fpModulus") -// generator.GenerateFF(fp, filepath.Join(baseDir, "fp")) -func GenerateFF(F *field.FieldConfig, outputDir string) error { +// fp, _ = config.NewField("fp", "Element", fpModulus") +// generator.GenerateFF(fp, filepath.Join(baseDir, "fp")) +func GenerateFF(F *config.FieldConfig, outputDir string) error { // source file templates sourceFiles := []string{ element.Base, element.Reduce, element.Exp, element.Conv, + element.MulDoc, element.MulCIOS, element.MulNoCarry, element.Sqrt, @@ -59,7 +57,10 @@ func GenerateFF(F *field.FieldConfig, outputDir string) error { oldFiles := []string{"_mul.go", "_mul_amd64.go", "_square.go", "_square_amd64.go", "_ops_decl.go", "_square_amd64.s", "_mul_amd64.s", + "_mul_arm64.s", + "_mul_arm64.go", "_ops_amd64.s", + "_ops_noasm.go", "_mul_adx_amd64.s", "_ops_amd64.go", "_fuzz.go", @@ -119,6 +120,8 @@ func GenerateFF(F *field.FieldConfig, outputDir string) error { return err } + _, _ = io.WriteString(f, "// +build !purego\n") + if err := amd64.Generate(f, F); err != nil { _ = f.Close() return err @@ -143,7 +146,7 @@ func GenerateFF(F *field.FieldConfig, outputDir string) error { return err } - _, _ = io.WriteString(f, "// +build !amd64_adx\n") + _, _ = io.WriteString(f, "// +build !purego\n") if err := amd64.GenerateMul(f, F); err != nil { _ = f.Close() @@ -161,41 +164,21 @@ func GenerateFF(F *field.FieldConfig, outputDir string) error { } } - { - pathSrc := filepath.Join(outputDir, eName+"_mul_adx_amd64.s") - fmt.Println("generating", pathSrc) - f, err := os.Create(pathSrc) - if err != nil { - return err - } - - _, _ = io.WriteString(f, "// +build amd64_adx\n") - - if err := amd64.GenerateMulADX(f, F); err != nil { - _ = f.Close() - return err - } - _ = f.Close() - - // run asmfmt - // run go fmt on whole directory - cmd := exec.Command("asmfmt", "-w", pathSrc) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return err - } - } - } if F.ASM { // generate ops_amd64.go src := []string{ + element.MulDoc, element.OpsAMD64, } pathSrc := filepath.Join(outputDir, eName+"_ops_amd64.go") - if err := bavard.GenerateFromString(pathSrc, src, F, bavardOpts...); err != nil { + bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) + copy(bavardOptsCpy, bavardOpts) + if F.ASM { + bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!purego")) + } + if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { return err } } @@ -207,12 +190,13 @@ func GenerateFF(F *field.FieldConfig, outputDir string) error { element.MulCIOS, element.MulNoCarry, element.Reduce, + element.MulDoc, } - pathSrc := filepath.Join(outputDir, eName+"_ops_noasm.go") + pathSrc := filepath.Join(outputDir, eName+"_ops_purego.go") bavardOptsCpy := make([]func(*bavard.Bavard) error, len(bavardOpts)) copy(bavardOptsCpy, bavardOpts) if F.ASM { - bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!amd64")) + bavardOptsCpy = append(bavardOptsCpy, bavard.BuildTag("!amd64 purego")) } if err := bavard.GenerateFromString(pathSrc, src, F, bavardOptsCpy...); err != nil { return err diff --git a/internal/field/generator/generator_test.go b/field/generator/generator_test.go similarity index 97% rename from internal/field/generator/generator_test.go rename to field/generator/generator_test.go index a818e0216..8a73997b3 100644 --- a/internal/field/generator/generator_test.go +++ b/field/generator/generator_test.go @@ -23,7 +23,7 @@ import ( "path/filepath" "testing" - "github.com/consensys/gnark-crypto/internal/field" + field "github.com/consensys/gnark-crypto/field/generator/config" ) // integration test will create modulus for various field sizes and run tests diff --git a/internal/field/internal/addchain/addchain.go b/field/generator/internal/addchain/addchain.go similarity index 100% rename from internal/field/internal/addchain/addchain.go rename to field/generator/internal/addchain/addchain.go diff --git a/internal/field/internal/templates/element/arith.go b/field/generator/internal/templates/element/arith.go similarity index 97% rename from internal/field/internal/templates/element/arith.go rename to field/generator/internal/templates/element/arith.go index 52da213c0..7bf808648 100644 --- a/internal/field/internal/templates/element/arith.go +++ b/field/generator/internal/templates/element/arith.go @@ -14,7 +14,6 @@ func madd0(a, b, c uint64) (hi uint64) { return } -{{- if ne .NbWords 1}} // madd1 hi, lo = a*b + c func madd1(a, b, c uint64) (hi uint64, lo uint64) { var carry uint64 @@ -60,7 +59,6 @@ func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { } return b } -{{- end}} {{- end}} diff --git a/internal/field/internal/templates/element/asm.go b/field/generator/internal/templates/element/asm.go similarity index 100% rename from internal/field/internal/templates/element/asm.go rename to field/generator/internal/templates/element/asm.go diff --git a/internal/field/internal/templates/element/base.go b/field/generator/internal/templates/element/base.go similarity index 82% rename from internal/field/internal/templates/element/base.go rename to field/generator/internal/templates/element/base.go index f0388acf2..ca033ca5c 100644 --- a/internal/field/internal/templates/element/base.go +++ b/field/generator/internal/templates/element/base.go @@ -3,12 +3,12 @@ package element const Base = ` import ( + "github.com/consensys/gnark-crypto/field" "math/big" "math/bits" "io" "crypto/rand" "encoding/binary" - "sync" "strconv" "errors" "reflect" @@ -16,9 +16,9 @@ import ( ) // {{.ElementName}} represents a field element stored on {{.NbWords}} words (uint64) -// +// // {{.ElementName}} are assumed to be in Montgomery form in all methods. -// +// // Modulus q = // // q[base10] = {{.Modulus}} @@ -32,7 +32,7 @@ type {{.ElementName}} [{{.NbWords}}]uint64 const ( Limbs = {{.NbWords}} // number of 64 bits words needed to represent a {{.ElementName}} Bits = {{.NbBits}} // number of bits needed to represent a {{.ElementName}} - Bytes = Limbs * 8 // number of bytes needed to represent a {{.ElementName}} + Bytes = {{.NbBytes}} // number of bytes needed to represent a {{.ElementName}} ) @@ -65,12 +65,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = {{index .QInverse 0}} -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("{{.ModulusHex}}", 16) } @@ -90,7 +84,7 @@ func New{{.ElementName}}(v uint64) {{.ElementName}} { func (z *{{.ElementName}}) SetUint64(v uint64) *{{.ElementName}} { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = {{.ElementName}}{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -242,16 +236,14 @@ func (z *{{.ElementName}}) IsUint64() bool { return true {{- else}} zz := *z - zz.FromMont() + zz.fromMont() return zz.FitsOnOneWord() {{- end}} } // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *{{.ElementName}}) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -272,10 +264,8 @@ func (z *{{.ElementName}}) FitsOnOneWord() bool { // +1 if z > x // func (z *{{.ElementName}}) Cmp(x *{{.ElementName}}) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() {{- range $i := reverse $.NbWordsIndexesFull}} if _z[{{$i}}] > _x[{{$i}}] { return 1 @@ -293,8 +283,7 @@ func (z *{{.ElementName}}) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], {{index .QMinusOneHalvedP 0}}, 0) @@ -307,7 +296,7 @@ func (z *{{.ElementName}}) LexicographicallyLargest() bool { // SetRandom sets z to a uniform random value in [0, q). // -// This might error only if reading from crypto/rand.Reader errors, +// This might error only if reading from crypto/rand.Reader errors, // in which case, value of z is undefined. func (z *{{.ElementName}}) SetRandom() (*{{.ElementName}}, error) { // this code is generated for all modulus @@ -384,7 +373,7 @@ func (z *{{.ElementName}}) Halve() { {{- if not .NoCarry}} if carry != 0 { - // when we added q, the result was larger than our avaible limbs + // when we added q, the result was larger than our available limbs // when we shift right, we need to set the highest bit z[{{.NbWordsLastIndex}}] |= (1 << 63) } @@ -392,7 +381,7 @@ func (z *{{.ElementName}}) Halve() { } {{ define "add_q" }} - // {{$.V1}} = {{$.V1}} + q + // {{$.V1}} = {{$.V1}} + q {{- range $i := $.all.NbWordsIndexesFull }} {{- $carryIn := ne $i 0}} {{- $carryOut := or (ne $i $.all.NbWordsLastIndex) (and (eq $i $.all.NbWordsLastIndex) (not $.all.NoCarry))}} @@ -401,82 +390,10 @@ func (z *{{.ElementName}}) Halve() { {{ end }} -// Mul z = x * y (mod q) -{{- if $.NoCarry}} -// -// x and y must be strictly inferior to q -{{- end }} -func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - {{- if $.NoCarry}} - // - // As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: - // - // for i=0 to N-1 - // (A,t[0]) := t[0] + x[0]*y[i] - // m := t[0]*q'[0] mod W - // C,_ := t[0] + m*q[0] - // for j=1 to N-1 - // (A,t[j]) := t[j] + x[j]*y[i] + A - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // t[N-1] = C + A - // - // This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit - // of the modulus is zero (and not all of the remaining bits are set). - {{- end}} - - {{- if eq $.NbWords 1}} - {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "y" }} - {{- else }} - mul(z, x, y) - {{- end }} - return z -} -// Square z = x * x (mod q) -{{- if $.NoCarry}} -// -// x must be strictly inferior to q -{{- end }} -func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { - // see Mul for algorithm documentation - {{- if eq $.NbWords 1}} - {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "x" }} - {{- else }} - mul(z, x, x) - {{- end }} - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *{{.ElementName}}) FromMont() *{{.ElementName}} { +func (z *{{.ElementName}}) fromMont() *{{.ElementName}} { fromMont(z) return z } @@ -521,7 +438,7 @@ func (z *{{.ElementName}}) Double( x *{{.ElementName}}) *{{.ElementName}} { {{- if eq .NbWords 1}} if x[0] & (1 << 63) == (1 << 63) { // if highest bit is set, then we have a carry to x + x, we shift and subtract q - z[0] = (x[0] << 1) - q + z[0] = (x[0] << 1) - q } else { // highest bit is not set, but x + x can still be >= q z[0] = (x[0] << 1) @@ -615,18 +532,13 @@ func (z *{{.ElementName}}) Select(c int, x0 *{{.ElementName}}, x1 *{{.ElementNam return z } - +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z,x,y *{{.ElementName}}) { - // see Mul for algorithm documentation - {{ if eq $.NbWords 1}} - {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "y" }} - {{ else if .NoCarry}} - {{ template "mul_nocarry" dict "all" . "V1" "x" "V2" "y"}} - {{ template "reduce" . }} - {{ else }} - {{ template "mul_cios" dict "all" . "V1" "x" "V2" "y" }} - {{ template "reduce" . }} - {{ end }} + {{ mul_doc false }} + {{ template "mul_cios" dict "all" . "V1" "x" "V2" "y"}} + {{ template "reduce" . }} } @@ -703,6 +615,34 @@ func (z *{{.ElementName}}) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]{{.ElementName}}, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]{{.ElementName}}, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} {{ define "rsh V nbWords" }} diff --git a/internal/field/internal/templates/element/bignum.go b/field/generator/internal/templates/element/bignum.go similarity index 100% rename from internal/field/internal/templates/element/bignum.go rename to field/generator/internal/templates/element/bignum.go diff --git a/internal/field/internal/templates/element/conv.go b/field/generator/internal/templates/element/conv.go similarity index 58% rename from internal/field/internal/templates/element/conv.go rename to field/generator/internal/templates/element/conv.go index 2633f448a..b8fc1f849 100644 --- a/internal/field/internal/templates/element/conv.go +++ b/field/generator/internal/templates/element/conv.go @@ -10,23 +10,32 @@ var rSquare = {{.ElementName}}{ {{$i}},{{end}} } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *{{.ElementName}}) ToMont() *{{.ElementName}} { +func (z *{{.ElementName}}) toMont() *{{.ElementName}} { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z {{.ElementName}}) ToRegular() {{.ElementName}} { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *{{.ElementName}}) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *{{.ElementName}}) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + {{- range $i := reverse .NbWordsIndexesFull}} + {{- $j := mul $i 8}} + {{- $k := sub $.NbWords 1}} + {{- $k := sub $k $i}} + {{- $jj := add $j 8}} + binary.BigEndian.PutUint64(b[{{$j}}:{{$jj}}], z[{{$k}}]) + {{- end}} + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -46,66 +55,61 @@ func (z *{{.ElementName}}) Text(base int) string { if base == 10 { var zzNeg {{.ElementName}} zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } - zz := *z - zz.FromMont() + zz := z.Bits() return strconv.FormatUint(zz[0], base) {{- else }} if base == 10 { var zzNeg {{.ElementName}} zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg.FitsOnOneWord() && zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } zz := *z - zz.FromMont() + zz.fromMont() if zz.FitsOnOneWord() { return strconv.FormatUint(zz[0], base) } - vv := bigIntPool.Get().(*big.Int) - r := zz.ToBigInt(vv).Text(base) - bigIntPool.Put(vv) + vv := field.BigIntPool.Get() + r := zz.toBigInt(vv).Text(base) + field.BigIntPool.Put(vv) return r {{- end}} } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *{{.ElementName}}) ToBigInt(res *big.Int) *big.Int { - var b [Limbs*8]byte - {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $k := sub $.NbWords 1}} - {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64(b[{{$j}}:{{$jj}}], z[{{$k}}]) - {{- end}} - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *{{.ElementName}}) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z {{.ElementName}}) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *{{.ElementName}}) Bytes() (res [Limbs*8]byte) { - _z := z.ToRegular() - {{- range $i := reverse .NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $k := sub $.NbWords 1}} - {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64(res[{{$j}}:{{$jj}}], _z[{{$k}}]) - {{- end}} +// Bits provides access to z by returning its value as a little-endian [{{.NbWords}}]uint64 array. +// Bits is intended to support implementation of missing low-level {{.ElementName}} +// functionality outside this package; it should be avoided otherwise. +func (z *{{.ElementName}}) Bits() [{{.NbWords}}]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *{{.ElementName}}) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -118,24 +122,42 @@ func (z *{{.ElementName}}) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *{{.ElementName}}) SetBytes(e []byte) *{{.ElementName}} { - {{- if eq .NbWords 1}} - if len(e) == 8 { + if len(e) == Bytes { // fast path - z[0] = binary.BigEndian.Uint64(e) - return z.ToMont() + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } } - {{- end}} + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) - return z + return z +} + +// SetBytesCanonical interprets e as the bytes of a big-endian {{.NbBytes}}-byte integer. +// If e is not a {{.NbBytes}}-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *{{.ElementName}}) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid {{.PackageName}}.{{.ElementName}} encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil } @@ -156,17 +178,16 @@ func (z *{{.ElementName}}) SetBigInt(v *big.Int) *{{.ElementName}} { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -188,7 +209,7 @@ func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z @@ -210,7 +231,7 @@ func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *{{.ElementName}}) SetString(number string) (*{{.ElementName}}, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("{{.ElementName}}.SetString failed -> can't parse number into a big.Int " + number) @@ -219,7 +240,7 @@ func (z *{{.ElementName}}) SetString(number string) (*{{.ElementName}}, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -260,7 +281,7 @@ func (z *{{.ElementName}}) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -269,9 +290,94 @@ func (z *{{.ElementName}}) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a {{.ElementName}} +type ByteOrder interface { + Element(*[Bytes]byte) ({{.ElementName}}, error) + PutElement(*[Bytes]byte, {{.ElementName}}) + String() string +} + + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian {{.NbBytes}}-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) ({{.ElementName}}, error) { + var z {{.ElementName}} + {{- range $i := reverse .NbWordsIndexesFull}} + {{- $j := mul $i 8}} + {{- $k := sub $.NbWords 1}} + {{- $k := sub $k $i}} + {{- $jj := add $j 8}} + z[{{$k}}] = binary.BigEndian.Uint64((*b)[{{$j}}:{{$jj}}]) + {{- end}} + + if !z.smallerThanModulus() { + return {{.ElementName}}{}, errors.New("invalid {{.PackageName}}.{{.ElementName}} encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e {{.ElementName}}) { + e.fromMont() + + {{- range $i := reverse .NbWordsIndexesFull}} + {{- $j := mul $i 8}} + {{- $k := sub $.NbWords 1}} + {{- $k := sub $k $i}} + {{- $jj := add $j 8}} + binary.BigEndian.PutUint64((*b)[{{$j}}:{{$jj}}], e[{{$k}}]) + {{- end}} +} + +func (bigEndian) String() string { return "BigEndian" } + + + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) ({{.ElementName}}, error) { + var z {{.ElementName}} + {{- range $i := .NbWordsIndexesFull}} + {{- $j := mul $i 8}} + {{- $jj := add $j 8}} + z[{{$i}}] = binary.LittleEndian.Uint64((*b)[{{$j}}:{{$jj}}]) + {{- end}} + + if !z.smallerThanModulus() { + return {{.ElementName}}{}, errors.New("invalid {{.PackageName}}.{{.ElementName}} encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e {{.ElementName}}) { + e.fromMont() + + {{- range $i := .NbWordsIndexesFull}} + {{- $j := mul $i 8}} + {{- $jj := add $j 8}} + binary.LittleEndian.PutUint64((*b)[{{$j}}:{{$jj}}], e[{{$i}}]) + {{- end}} +} + +func (littleEndian) String() string { return "LittleEndian" } + + + + ` diff --git a/internal/field/internal/templates/element/doc.go b/field/generator/internal/templates/element/doc.go similarity index 100% rename from internal/field/internal/templates/element/doc.go rename to field/generator/internal/templates/element/doc.go diff --git a/internal/field/internal/templates/element/exp.go b/field/generator/internal/templates/element/exp.go similarity index 89% rename from internal/field/internal/templates/element/exp.go rename to field/generator/internal/templates/element/exp.go index f18286f88..1dc63f1e0 100644 --- a/internal/field/internal/templates/element/exp.go +++ b/field/generator/internal/templates/element/exp.go @@ -15,8 +15,8 @@ func (z *{{.ElementName}}) Exp(x {{.ElementName}}, k *big.Int) *{{.ElementName}} // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } diff --git a/internal/field/internal/templates/element/fixed_exp.go b/field/generator/internal/templates/element/fixed_exp.go similarity index 100% rename from internal/field/internal/templates/element/fixed_exp.go rename to field/generator/internal/templates/element/fixed_exp.go diff --git a/internal/field/internal/templates/element/inverse.go b/field/generator/internal/templates/element/inverse.go similarity index 96% rename from internal/field/internal/templates/element/inverse.go rename to field/generator/internal/templates/element/inverse.go index 9d4dc648c..6adaedc55 100644 --- a/internal/field/internal/templates/element/inverse.go +++ b/field/generator/internal/templates/element/inverse.go @@ -34,7 +34,7 @@ func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { var r,s,u,v uint64 u = q - s = {{index .RSquare 0}} // s = r^2 + s = {{index .RSquare 0}} // s = r² r = 0 v = x[0] @@ -94,7 +94,7 @@ func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { // note: allocates a big.Int (math/big) func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { var _xNonMont big.Int - x.ToBigIntRegular(&_xNonMont) + x.BigInt(&_xNonMont) _xNonMont.ModInverse(&_xNonMont, Modulus()) z.SetBigInt(&_xNonMont) return z @@ -258,7 +258,7 @@ func (z *{{.ElementName}}) Inverse(x *{{.ElementName}}) *{{.ElementName}} { // we would multiply by pSq up to 13times; // on x86, the assembly routine outperforms generic code for mul by word // on arm64, we may loose up to ~5% for 6 limbs - mul(&v, &v, &a) + v.Mul(&v, &a) } u.Set(x) // for correctness check @@ -272,17 +272,28 @@ func (z *{{.ElementName}}) Inverse(x *{{.ElementName}}) *{{.ElementName}} { // correctness check v.Mul(&u, z) if !v.IsOne() && !u.IsZero() { - return z.inverseExp(&u) + return z.inverseExp(u) } return z } // inverseExp computes z = x⁻¹ (mod q) = x**(q-2) (mod q) -func (z *{{.ElementName}}) inverseExp(x *{{.ElementName}}) *{{.ElementName}} { - qMinusTwo := Modulus() - qMinusTwo.Sub(qMinusTwo, big.NewInt(2)) - return z.Exp(*x, qMinusTwo) +func (z *{{.ElementName}}) inverseExp(x {{.ElementName}}) *{{.ElementName}} { + // e == q-2 + e := Modulus() + e.Sub(e, big.NewInt(2)) + + z.Set(&x) + + for i := e.BitLen() - 2; i >= 0; i-- { + z.Square(z) + if e.Bit(i) == 1 { + z.Mul(z, &x) + } + } + + return z } // approximate a big number x into a single 64 bit word using its uppermost and lowermost bits diff --git a/internal/field/internal/templates/element/inverse_tests.go b/field/generator/internal/templates/element/inverse_tests.go similarity index 98% rename from internal/field/internal/templates/element/inverse_tests.go rename to field/generator/internal/templates/element/inverse_tests.go index 7effe660c..d85d1ead8 100644 --- a/internal/field/internal/templates/element/inverse_tests.go +++ b/field/generator/internal/templates/element/inverse_tests.go @@ -37,7 +37,7 @@ func Test{{.ElementName}}InversionCorrectionFactorFormula(t *testing.T) { inversionCorrectionFactorWord{{$i}}, {{- end}} } - inversionCorrectionFactor.ToBigInt(&refFactorInt) + inversionCorrectionFactor.toBigInt(&refFactorInt) if refFactorInt.Cmp(factorInt) != 0 { t.Error("mismatch") @@ -79,7 +79,7 @@ func Test{{.ElementName}}InversionCorrectionFactor(t *testing.T) { if !oneInv.Equal(&one) { var i big.Int - oneInv.ToBigIntRegular(&i) // no montgomery + oneInv.BigInt(&i) // no montgomery i.ModInverse(&i, Modulus()) var fac {{.ElementName}} fac.setBigInt(&i) // back to montgomery @@ -439,11 +439,11 @@ func randomizeUpdateFactors() (int64, int64) { func testLinearComb(t *testing.T, x *{{.ElementName}}, xC int64, y *{{.ElementName}}, yC int64) { var p1 big.Int - x.ToBigInt(&p1) + x.toBigInt(&p1) p1.Mul(&p1, big.NewInt(xC)) var p2 big.Int - y.ToBigInt(&p2) + y.toBigInt(&p2) p2.Mul(&p2, big.NewInt(yC)) p1.Add(&p1, &p2) @@ -479,7 +479,7 @@ func montReduce(res *big.Int, x *big.Int) { } func (z *{{.ElementName}}) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { - z.ToBigInt(i) + z.toBigInt(i) var upperWord big.Int upperWord.SetUint64(xHi) upperWord.Lsh(&upperWord, Limbs*64) @@ -497,7 +497,7 @@ func (z *{{.ElementName}}) toVeryBigIntSigned(i *big.Int, xHi uint64) { func assertMulProduct(t *testing.T, x *{{.ElementName}}, c int64, result *{{.ElementName}}, resultHi uint64) big.Int { var xInt big.Int - x.ToBigInt(&xInt) + x.toBigInt(&xInt) xInt.Mul(&xInt, big.NewInt(c)) @@ -508,7 +508,7 @@ func assertMulProduct(t *testing.T, x *{{.ElementName}}, c int64, result *{{.Ele func approximateRef(x *{{.ElementName}}) uint64 { var asInt big.Int - x.ToBigInt(&asInt) + x.toBigInt(&asInt) n := x.BitLen() if n <= 64 { diff --git a/field/generator/internal/templates/element/mul_cios.go b/field/generator/internal/templates/element/mul_cios.go new file mode 100644 index 000000000..b7bccb432 --- /dev/null +++ b/field/generator/internal/templates/element/mul_cios.go @@ -0,0 +1,174 @@ +package element + +// MulCIOS text book CIOS works for all modulus. +// +// There are couple of variations to the multiplication (and squaring) algorithms. +// +// All versions are derived from the Montgomery CIOS algorithm: see +// section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +// +// For 1-word modulus, the generator will call mul_cios_one_limb (standard REDC) +// +// For 13-word+ modulus, the generator will output a unoptimized textbook CIOS code, in plain Go. +// +// For all other moduli, we look at the available bits in the last limb. +// If they are none (like secp256k1) we generate a unoptimized textbook CIOS code, in plain Go, for all architectures. +// If there is at least one we can ommit a carry propagation in the CIOS algorithm. +// If there is at least two we can use the same technique for the CIOS Squaring. +// See appendix in https://eprint.iacr.org/2022/1400.pdf for the exact condition. +// +// In practice, we have 3 differents targets in mind: x86(amd64), arm64 and wasm. +// +// For amd64, we can leverage (when available) the BMI2 and ADX instructions to have 2-carry-chains in parallel. +// This make the use of assembly worth it as it results in a significant perf improvment; most CPUs since 2016 support these +// instructions, and we assume it to be the "default path"; in case the CPU has no support, we fall back to a slow, unoptimized version. +// +// On amd64, the Squaring algorithm always call the Multiplication (assembly) implementation. +// +// For arm64, we unroll the loops in the CIOS (+nocarry optimization) algorithm, such that the instructions generated +// by the Go compiler closely match what we would hand-write. Hence, there is no assembly needed for arm64 target. +// +// Additionally, if 2-bits+ are available on the last limb, we have a template to generate a dedicated Squaring algorithm +// This is not activated by default, to minimize the codebase size. +// On M1, AWS Graviton3 it results in a 5-10% speedup. On some mobile devices, speed up observed was more important (~20%). +// +// The same (arm64) unrolled Go code produce satisfying perfomrance for WASM (compiled using TinyGo). +const MulCIOS = ` +{{ define "mul_cios" }} + var t [{{add .all.NbWords 1}}]uint64 + var D uint64 + var m, C uint64 + + {{- range $j := .all.NbWordsIndexesFull}} + // ----------------------------------- + // First loop + {{ if eq $j 0}} + C, t[0] = bits.Mul64({{$.V2}}[{{$j}}], {{$.V1}}[0]) + {{- range $i := $.all.NbWordsIndexesNoZero}} + C, t[{{$i}}] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], C) + {{- end}} + {{ else }} + C, t[0] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[0], t[0]) + {{- range $i := $.all.NbWordsIndexesNoZero}} + C, t[{{$i}}] = madd2({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], t[{{$i}}], C) + {{- end}} + {{ end }} + t[{{$.all.NbWords}}], D = bits.Add64(t[{{$.all.NbWords}}], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + {{- range $i := $.all.NbWordsIndexesNoZero}} + C, t[{{sub $i 1}}] = madd2(m, q{{$i}}, t[{{$i}}], C) + {{- end}} + + t[{{sub $.all.NbWords 1}}], C = bits.Add64(t[{{$.all.NbWords}}], C, 0) + t[{{$.all.NbWords}}], _ = bits.Add64(0, D, C) + {{- end}} + + + if t[{{$.all.NbWords}}] != 0 { + // we need to reduce, we have a result on {{add 1 $.all.NbWords}} words + {{- if gt $.all.NbWords 1}} + var b uint64 + {{- end}} + z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.Sub64(t[0], q0, 0) + {{- range $i := .all.NbWordsIndexesNoZero}} + {{- if eq $i $.all.NbWordsLastIndex}} + z[{{$i}}], _ = bits.Sub64(t[{{$i}}], q{{$i}}, b) + {{- else }} + z[{{$i}}], b = bits.Sub64(t[{{$i}}], q{{$i}}, b) + {{- end}} + {{- end}} + return {{if $.ReturnZ }} z{{- end}} + } + + // copy t into z + {{- range $i := $.all.NbWordsIndexesFull}} + z[{{$i}}] = t[{{$i}}] + {{- end}} + +{{ end }} + +{{ define "mul_cios_one_limb" }} + // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): + // hi, lo := x * y + // m := (lo * qInvNeg) mod R + // (*) r := (hi * R + lo + m * q) / R + // reduce r if necessary + + // On the emphasized line, we get r = hi + (lo + m * q) / R + // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 + // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. + // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) + // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. + + var r uint64 + hi, lo := bits.Mul64({{$.V1}}[0], {{$.V2}}[0]) + if lo != 0 { + hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow + } + m := lo * qInvNeg + hi2, _ := bits.Mul64(m, q) + r, carry := bits.Add64(hi2, hi, 0) + + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r +{{ end }} +` + +const MulDoc = ` +{{define "mul_doc noCarry"}} +// Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis +// https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf +// +// The algorithm: +// +// for i=0 to N-1 +// C := 0 +// for j=0 to N-1 +// (C,t[j]) := t[j] + x[j]*y[i] + C +// (t[N+1],t[N]) := t[N] + C +// +// C := 0 +// m := t[0]*q'[0] mod D +// (C,_) := t[0] + m*q[0] +// for j=1 to N-1 +// (C,t[j-1]) := t[j] + m*q[j] + C +// +// (C,t[N-1]) := t[N] + C +// t[N] := t[N+1] + C +// +// → N is the number of machine words needed to store the modulus q +// → D is the word size. For example, on a 64-bit architecture D is 2 64 +// → x[i], y[i], q[i] is the ith word of the numbers x,y,q +// → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. +// → t is a temporary array of size N+2 +// → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number +{{- if .noCarry}} +// +// As described here https://hackmd.io/@gnark/modular_multiplication we can get rid of one carry chain and simplify: +// (also described in https://eprint.iacr.org/2022/1400.pdf annex) +// +// for i=0 to N-1 +// (A,t[0]) := t[0] + x[0]*y[i] +// m := t[0]*q'[0] mod W +// C,_ := t[0] + m*q[0] +// for j=1 to N-1 +// (A,t[j]) := t[j] + x[j]*y[i] + A +// (C,t[j-1]) := t[j] + m*q[j] + C +// +// t[N-1] = C + A +// +// This optimization saves 5N + 2 additions in the algorithm, and can be used whenever the highest bit +// of the modulus is zero (and not all of the remaining bits are set). +{{- end}} +{{ end }} +` diff --git a/field/generator/internal/templates/element/mul_nocarry.go b/field/generator/internal/templates/element/mul_nocarry.go new file mode 100644 index 000000000..0ec89f7a8 --- /dev/null +++ b/field/generator/internal/templates/element/mul_nocarry.go @@ -0,0 +1,224 @@ +package element + +// MulNoCarry see https://eprint.iacr.org/2022/1400.pdf annex for more info on the algorithm +// Note that these templates are optimized for arm64 target, since x86 benefits from assembly impl. +const MulNoCarry = ` +{{ define "mul_nocarry" }} +var {{range $i := .all.NbWordsIndexesFull}}t{{$i}}{{- if ne $i $.all.NbWordsLastIndex}},{{- end}}{{- end}} uint64 +var {{range $i := .all.NbWordsIndexesFull}}u{{$i}}{{- if ne $i $.all.NbWordsLastIndex}},{{- end}}{{- end}} uint64 +{{- range $i := .all.NbWordsIndexesFull}} +{ + var c0, c1, c2 uint64 + v := {{$.V1}}[{{$i}}] + {{- if eq $i 0}} + {{- range $j := $.all.NbWordsIndexesFull}} + u{{$j}}, t{{$j}} = bits.Mul64(v, {{$.V2}}[{{$j}}]) + {{- end}} + {{- else}} + {{- range $j := $.all.NbWordsIndexesFull}} + u{{$j}}, c1 = bits.Mul64(v, {{$.V2}}[{{$j}}]) + {{- if eq $j 0}} + t{{$j}}, c0 = bits.Add64(c1, t{{$j}}, 0) + {{- else }} + t{{$j}}, c0 = bits.Add64(c1, t{{$j}}, c0) + {{- end}} + {{- if eq $j $.all.NbWordsLastIndex}} + {{/* yes, we're tempted to write c2 = c0, but that slow the whole MUL by 20% */}} + c2, _ = bits.Add64(0, 0, c0) + {{- end}} + {{- end}} + {{- end}} + + {{- range $j := $.all.NbWordsIndexesFull}} + {{- if eq $j 0}} + t{{add $j 1}}, c0 = bits.Add64(u{{$j}}, t{{add $j 1}}, 0) + {{- else if eq $j $.all.NbWordsLastIndex}} + {{- if eq $i 0}} + c2, _ = bits.Add64(u{{$j}}, 0, c0) + {{- else}} + c2, _ = bits.Add64(u{{$j}},c2, c0) + {{- end}} + {{- else }} + t{{add $j 1}}, c0 = bits.Add64(u{{$j}}, t{{add $j 1}}, c0) + {{- end}} + {{- end}} + + {{- $k := $.all.NbWordsLastIndex}} + + m := qInvNeg * t0 + + u0, c1 = bits.Mul64(m, q0) + {{- range $j := $.all.NbWordsIndexesFull}} + {{- if ne $j 0}} + {{- if eq $j 1}} + _, c0 = bits.Add64(t0, c1, 0) + {{- else}} + t{{sub $j 2}}, c0 = bits.Add64(t{{sub $j 1}}, c1, c0) + {{- end}} + u{{$j}}, c1 = bits.Mul64(m, q{{$j}}) + {{- end}} + {{- end}} + {{/* TODO @gbotrel it seems this can create a carry (c0) -- study the bounds */}} + t{{sub $.all.NbWordsLastIndex 1}}, c0 = bits.Add64(0, c1, c0) + u{{$k}}, _ = bits.Add64(u{{$k}}, 0, c0) + + {{- range $j := $.all.NbWordsIndexesFull}} + {{- if eq $j 0}} + t{{$j}}, c0 = bits.Add64(u{{$j}}, t{{$j}}, 0) + {{- else if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, 0, c0) + {{- else}} + t{{$j}}, c0 = bits.Add64(u{{$j}}, t{{$j}}, c0) + {{- end}} + {{- end}} + + {{- $l := sub $.all.NbWordsLastIndex 1}} + t{{$l}}, c0 = bits.Add64(t{{$k}}, t{{$l}}, 0) + t{{$k}}, _ = bits.Add64(u{{$k}}, c2, c0) + +} +{{- end}} + + +{{- range $i := $.all.NbWordsIndexesFull}} +z[{{$i}}] = t{{$i}} +{{- end}} + +{{ end }} + + + +{{ define "square_nocarry" }} +var {{range $i := .all.NbWordsIndexesFull}}t{{$i}}{{- if ne $i $.all.NbWordsLastIndex}},{{- end}}{{- end}} uint64 +var {{range $i := $.all.NbWordsIndexesFull}}u{{$i}}{{- if ne $i $.all.NbWordsLastIndex}},{{- end}}{{- end}} uint64 +var {{range $i := interval 0 (add $.all.NbWordsLastIndex 1)}}lo{{$i}}{{- if ne $i $.all.NbWordsLastIndex}},{{- end}}{{- end}} uint64 + +// note that if hi, _ = bits.Mul64() didn't generate +// UMULH and MUL, (but just UMULH) we could use same pattern +// as in mulRaw and reduce the stack space of this function (no need for lo..) + +{{- range $i := .all.NbWordsIndexesFull}} +{ + + {{$jStart := add $i 1}} + {{$jEnd := add $.all.NbWordsLastIndex 1}} + + var c0, c2 uint64 + + + // for j=i+1 to N-1 + // p,C,t[j] = 2*a[j]*a[i] + t[j] + (p,C) + // A = C + + {{- if eq $i 0}} + u{{$i}}, lo1 = bits.Mul64(x[{{$i}}], x[{{$i}}]) + {{- range $j := interval $jStart $jEnd}} + u{{$j}}, t{{$j}} = bits.Mul64(x[{{$j}}], x[{{$i}}]) + {{- end}} + + // propagate lo, from t[j] to end, twice. + {{- range $j := interval $jStart $jEnd}} + {{- if eq $j $jStart}} + t{{$j}}, c0 = bits.Add64(t{{$j}}, t{{$j}}, 0) + {{- else }} + t{{$j}}, c0 = bits.Add64(t{{$j}}, t{{$j}}, c0) + {{- end}} + {{- if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, 0, c0) + {{- end}} + {{- end}} + + t{{$i}}, c0 = bits.Add64( lo1,t{{$i}}, 0) + {{- else}} + {{- range $j := interval (sub $jStart 1) $jEnd}} + u{{$j}}, lo{{$j}} = bits.Mul64(x[{{$j}}], x[{{$i}}]) + {{- end}} + + // propagate lo, from t[j] to end, twice. + {{- range $j := interval $jStart $jEnd}} + {{- if eq $j $jStart}} + lo{{$j}}, c0 = bits.Add64(lo{{$j}}, lo{{$j}}, 0) + {{- else }} + lo{{$j}}, c0 = bits.Add64(lo{{$j}}, lo{{$j}}, c0) + {{- end}} + {{- if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, 0, c0) + {{- end}} + {{- end}} + {{- range $j := interval $jStart $jEnd}} + {{- if eq $j $jStart}} + t{{$j}}, c0 = bits.Add64(lo{{$j}}, t{{$j}}, 0) + {{- else }} + t{{$j}}, c0 = bits.Add64(lo{{$j}}, t{{$j}}, c0) + {{- end}} + {{- if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, 0, c0) + {{- end}} + {{- end}} + + t{{$i}}, c0 = bits.Add64( lo{{$i}},t{{$i}}, 0) + {{- end}} + + + // propagate u{{$i}} + hi + {{- range $j := interval $jStart $jEnd}} + t{{$j}}, c0 = bits.Add64(u{{sub $j 1}}, t{{$j}}, c0) + {{- end}} + c2, _ = bits.Add64(u{{$.all.NbWordsLastIndex}}, c2, c0) + + // hi again + {{- range $j := interval $jStart $jEnd}} + {{- if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, u{{$j}}, {{- if eq $j $jStart}} 0 {{- else}}c0{{- end}}) + {{- else if eq $j $jStart}} + t{{add $j 1}}, c0 = bits.Add64(u{{$j}}, t{{add $j 1}}, 0) + {{- else }} + t{{add $j 1}}, c0 = bits.Add64(u{{$j}}, t{{add $j 1}}, c0) + {{- end}} + {{- end}} + + {{- $k := $.all.NbWordsLastIndex}} + + // this part is unchanged. + m := qInvNeg * t0 + {{- range $j := $.all.NbWordsIndexesFull}} + u{{$j}}, lo{{$j}} = bits.Mul64(m, q{{$j}}) + {{- end}} + {{- range $j := $.all.NbWordsIndexesFull}} + {{- if ne $j 0}} + {{- if eq $j 1}} + _, c0 = bits.Add64(t0, lo{{sub $j 1}}, 0) + {{- else}} + t{{sub $j 2}}, c0 = bits.Add64(t{{sub $j 1}}, lo{{sub $j 1}}, c0) + {{- end}} + {{- end}} + {{- end}} + t{{sub $.all.NbWordsLastIndex 1}}, c0 = bits.Add64(0, lo{{$.all.NbWordsLastIndex}}, c0) + u{{$k}}, _ = bits.Add64(u{{$k}}, 0, c0) + + {{- range $j := $.all.NbWordsIndexesFull}} + {{- if eq $j 0}} + t{{$j}}, c0 = bits.Add64(u{{$j}}, t{{$j}}, 0) + {{- else if eq $j $.all.NbWordsLastIndex}} + c2, _ = bits.Add64(c2, 0, c0) + {{- else}} + t{{$j}}, c0 = bits.Add64(u{{$j}}, t{{$j}}, c0) + {{- end}} + {{- end}} + + {{- $k := sub $.all.NbWordsLastIndex 0}} + {{- $l := sub $.all.NbWordsLastIndex 1}} + t{{$l}}, c0 = bits.Add64(t{{$k}}, t{{$l}}, 0) + t{{$k}}, _ = bits.Add64(u{{$k}}, c2, c0) +} +{{- end}} + + +{{- range $i := $.all.NbWordsIndexesFull}} +z[{{$i}}] = t{{$i}} +{{- end}} + +{{ end }} + + +` diff --git a/internal/field/internal/templates/element/ops.go b/field/generator/internal/templates/element/ops_asm.go similarity index 60% rename from internal/field/internal/templates/element/ops.go rename to field/generator/internal/templates/element/ops_asm.go index 1eb1db15d..8f5c01510 100644 --- a/internal/field/internal/templates/element/ops.go +++ b/field/generator/internal/templates/element/ops_asm.go @@ -29,6 +29,26 @@ func reduce(res *{{.ElementName}}) //go:noescape func Butterfly(a, b *{{.ElementName}}) + + +// Mul z = x * y (mod q) +// +// x and y must be less than q +func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { + {{ mul_doc $.NoCarry }} + mul(z, x, y) + return z +} + +// Square z = x * x (mod q) +// +// x must be less than q +func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { + // see Mul for doc. + mul(z, x, x) + return z +} + {{end}} diff --git a/field/generator/internal/templates/element/ops_purego.go b/field/generator/internal/templates/element/ops_purego.go new file mode 100644 index 000000000..9447465c1 --- /dev/null +++ b/field/generator/internal/templates/element/ops_purego.go @@ -0,0 +1,95 @@ +package element + +const OpsNoAsm = ` + +import "math/bits" + +{{ $mulConsts := list 3 5 13 }} +{{- range $i := $mulConsts }} + +// MulBy{{$i}} x *= {{$i}} (mod q) +func MulBy{{$i}}(x *{{$.ElementName}}) { + {{- if eq 1 $.NbWords}} + var y {{$.ElementName}} + y.SetUint64({{$i}}) + x.Mul(x, &y) + {{- else}} + {{- if eq $i 3}} + _x := *x + x.Double(x).Add(x, &_x) + {{- else if eq $i 5}} + _x := *x + x.Double(x).Double(x).Add(x, &_x) + {{- else if eq $i 13}} + var y = {{$.ElementName}}{ + {{- range $i := $.Thirteen}} + {{$i}},{{end}} + } + x.Mul(x, &y) + {{- else }} + NOT IMPLEMENTED + {{- end}} + {{- end}} +} + +{{- end}} + +// Butterfly sets +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *{{.ElementName}}) { + _butterflyGeneric(a, b) +} + + +func fromMont(z *{{.ElementName}} ) { + _fromMontGeneric(z) +} + +func reduce(z *{{.ElementName}}) { + _reduceGeneric(z) +} + + + +// Mul z = x * y (mod q) +{{- if $.NoCarry}} +// +// x and y must be less than q +{{- end }} +func (z *{{.ElementName}}) Mul(x, y *{{.ElementName}}) *{{.ElementName}} { + {{- if eq $.NbWords 1}} + {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "y" }} + {{- else }} + {{ mul_doc $.NoCarry }} + {{- if $.NoCarry}} + {{ template "mul_nocarry" dict "all" . "V1" "x" "V2" "y"}} + {{- else}} + {{ template "mul_cios" dict "all" . "V1" "x" "V2" "y" "ReturnZ" true}} + {{- end}} + {{ template "reduce" . }} + {{- end }} + return z +} + +// Square z = x * x (mod q) +{{- if $.NoCarry}} +// +// x must be less than q +{{- end }} +func (z *{{.ElementName}}) Square(x *{{.ElementName}}) *{{.ElementName}} { + // see Mul for algorithm documentation + {{- if eq $.NbWords 1}} + {{ template "mul_cios_one_limb" dict "all" . "V1" "x" "V2" "x" }} + {{- else }} + {{- if $.NoCarry}} + {{ template "mul_nocarry" dict "all" . "V1" "x" "V2" "x"}} + {{- else}} + {{ template "mul_cios" dict "all" . "V1" "x" "V2" "x" "ReturnZ" true}} + {{- end}} + {{ template "reduce" . }} + {{- end }} + return z +} + +` diff --git a/internal/field/internal/templates/element/reduce.go b/field/generator/internal/templates/element/reduce.go similarity index 94% rename from internal/field/internal/templates/element/reduce.go rename to field/generator/internal/templates/element/reduce.go index 3faba6276..ae5d069d2 100644 --- a/internal/field/internal/templates/element/reduce.go +++ b/field/generator/internal/templates/element/reduce.go @@ -2,7 +2,7 @@ package element const Reduce = ` {{ define "reduce" }} -// if z >= q → z -= q +// if z ⩾ q → z -= q if !z.smallerThanModulus() { {{- if eq $.NbWords 1}} z[0] -= q diff --git a/internal/field/internal/templates/element/sqrt.go b/field/generator/internal/templates/element/sqrt.go similarity index 97% rename from internal/field/internal/templates/element/sqrt.go rename to field/generator/internal/templates/element/sqrt.go index 2a58ae92f..2745ca521 100644 --- a/internal/field/internal/templates/element/sqrt.go +++ b/field/generator/internal/templates/element/sqrt.go @@ -102,7 +102,7 @@ func (z *{{.ElementName}}) Sqrt(x *{{.ElementName}}) *{{.ElementName}} { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -113,7 +113,7 @@ func (z *{{.ElementName}}) Sqrt(x *{{.ElementName}}) *{{.ElementName}} { r := uint64({{.SqrtE}}) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i:=uint64(0); i < r-1; i++ { t.Square(&t) diff --git a/internal/field/internal/templates/element/tests.go b/field/generator/internal/templates/element/tests.go similarity index 97% rename from internal/field/internal/templates/element/tests.go rename to field/generator/internal/templates/element/tests.go index 8eba7760d..5f06ed497 100644 --- a/internal/field/internal/templates/element/tests.go +++ b/field/generator/internal/templates/element/tests.go @@ -10,7 +10,7 @@ import ( "math/bits" "fmt" {{if .UsingP20Inverse}} - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field" mrand "math/rand" {{end}} "testing" @@ -174,17 +174,10 @@ func Benchmark{{toTitle .ElementName}}FromMont(b *testing.B) { benchRes{{.ElementName}}.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchRes{{.ElementName}}.FromMont() + benchRes{{.ElementName}}.fromMont() } } -func Benchmark{{toTitle .ElementName}}ToMont(b *testing.B) { - benchRes{{.ElementName}}.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchRes{{.ElementName}}.ToMont() - } -} func Benchmark{{toTitle .ElementName}}Square(b *testing.B) { benchRes{{.ElementName}}.SetRandom() b.ResetTimer() @@ -654,7 +647,7 @@ func Test{{toTitle .ElementName}}BitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPair{{.ElementName}}) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -804,7 +797,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { {{- end }} - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -817,7 +810,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c {{.all.ElementName}} {{- if eq .Op "Div"}} @@ -842,7 +835,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { } {{end}} - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -888,11 +881,11 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c {{.all.ElementName}} @@ -920,7 +913,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { {{end}} - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("{{.Op}} failed special test values") } } @@ -984,7 +977,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { {{- end }} - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1019,7 +1012,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c {{.all.ElementName}} c.{{.Op}}(&a) @@ -1046,7 +1039,7 @@ func Test{{toTitle .all.ElementName}}{{.Op}}(t *testing.T) { {{end}} - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("{{.Op}} failed special test values") } } @@ -1355,7 +1348,7 @@ func Test{{toTitle .ElementName}}NegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -1492,17 +1485,17 @@ func Test{{toTitle .ElementName}}FromMont(t *testing.T) { func(a testPair{{.ElementName}}) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPair{{.ElementName}}) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -1599,7 +1592,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } diff --git a/field/goff/cmd/root.go b/field/goff/cmd/root.go index 2d5aedfbe..15ed94fc0 100644 --- a/field/goff/cmd/root.go +++ b/field/goff/cmd/root.go @@ -22,8 +22,8 @@ import ( "path/filepath" "strings" - "github.com/consensys/gnark-crypto/internal/field" - "github.com/consensys/gnark-crypto/internal/field/generator" + "github.com/consensys/gnark-crypto/field/generator" + field "github.com/consensys/gnark-crypto/field/generator/config" "github.com/spf13/cobra" ) diff --git a/field/goff/main.go b/field/goff/main.go index 54ff21a9c..92cec47e5 100644 --- a/field/goff/main.go +++ b/field/goff/main.go @@ -17,9 +17,10 @@ // Generated code is optimized for x86 (amd64) targets, and most methods do not allocate memory on the heap. // // Example usage: -// goff -m 0xffffffff00000001 -o ./goldilocks/ -p goldilocks -e Element // -// Warning +// goff -m 0xffffffff00000001 -o ./goldilocks/ -p goldilocks -e Element +// +// # Warning // // The generated code has not been audited for all moduli (only bn254 and bls12-381) and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package main diff --git a/field/goldilocks/arith.go b/field/goldilocks/arith.go index ec9b1faf5..99941ffdf 100644 --- a/field/goldilocks/arith.go +++ b/field/goldilocks/arith.go @@ -28,3 +28,33 @@ func madd0(a, b, c uint64) (hi uint64) { hi, _ = bits.Add64(hi, 0, carry) return } + +// madd1 hi, lo = a*b + c +func madd1(a, b, c uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +// madd2 hi, lo = a*b + c + d +func madd2(a, b, c, d uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, 0, carry) + return +} + +func madd3(a, b, c, d, e uint64) (hi uint64, lo uint64) { + var carry uint64 + hi, lo = bits.Mul64(a, b) + c, carry = bits.Add64(c, d, 0) + hi, _ = bits.Add64(hi, 0, carry) + lo, carry = bits.Add64(lo, c, 0) + hi, _ = bits.Add64(hi, e, carry) + return +} diff --git a/field/goldilocks/doc.go b/field/goldilocks/doc.go index 48548d3fe..076637815 100644 --- a/field/goldilocks/doc.go +++ b/field/goldilocks/doc.go @@ -21,30 +21,33 @@ // The modulus is hardcoded in all the operations. // // Field elements are represented as an array, and assumed to be in Montgomery form in all methods: -// type Element [1]uint64 // -// Usage +// type Element [1]uint64 +// +// # Usage // // Example API signature: -// // Mul z = x * y (mod q) -// func (z *Element) Mul(x, y *Element) *Element +// +// // Mul z = x * y (mod q) +// func (z *Element) Mul(x, y *Element) *Element // // and can be used like so: -// var a, b Element -// a.SetUint64(2) -// b.SetString("984896738") -// a.Mul(a, b) -// a.Sub(a, a) -// .Add(a, b) -// .Inv(a) -// b.Exp(b, new(big.Int).SetUint64(42)) +// +// var a, b Element +// a.SetUint64(2) +// b.SetString("984896738") +// a.Mul(a, b) +// a.Sub(a, a) +// .Add(a, b) +// .Inv(a) +// b.Exp(b, new(big.Int).SetUint64(42)) // // Modulus q = // -// q[base10] = 18446744069414584321 -// q[base16] = 0xffffffff00000001 +// q[base10] = 18446744069414584321 +// q[base16] = 0xffffffff00000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. package goldilocks diff --git a/field/goldilocks/element.go b/field/goldilocks/element.go index 5f8df411a..d6c244250 100644 --- a/field/goldilocks/element.go +++ b/field/goldilocks/element.go @@ -20,13 +20,13 @@ import ( "crypto/rand" "encoding/binary" "errors" + "github.com/consensys/gnark-crypto/field" "io" "math/big" "math/bits" "reflect" "strconv" "strings" - "sync" ) // Element represents a field element stored on 1 words (uint64) @@ -35,18 +35,18 @@ import ( // // Modulus q = // -// q[base10] = 18446744069414584321 -// q[base16] = 0xffffffff00000001 +// q[base10] = 18446744069414584321 +// q[base16] = 0xffffffff00000001 // -// Warning +// # Warning // // This code has not been audited and is provided as-is. In particular, there is no security guarantees such as constant time implementation or side-channel attack resistance. type Element [1]uint64 const ( - Limbs = 1 // number of 64 bits words needed to represent a Element - Bits = 64 // number of bits needed to represent a Element - Bytes = Limbs * 8 // number of bytes needed to represent a Element + Limbs = 1 // number of 64 bits words needed to represent a Element + Bits = 64 // number of bits needed to represent a Element + Bytes = 8 // number of bytes needed to represent a Element ) // Field modulus q @@ -63,8 +63,8 @@ var _modulus big.Int // q stored as big.Int // Modulus returns q as a big.Int // -// q[base10] = 18446744069414584321 -// q[base16] = 0xffffffff00000001 +// q[base10] = 18446744069414584321 +// q[base16] = 0xffffffff00000001 func Modulus() *big.Int { return new(big.Int).Set(&_modulus) } @@ -73,12 +73,6 @@ func Modulus() *big.Int { // used for Montgomery reduction const qInvNeg uint64 = 18446744069414584319 -var bigIntPool = sync.Pool{ - New: func() interface{} { - return new(big.Int) - }, -} - func init() { _modulus.SetString("ffffffff00000001", 16) } @@ -86,8 +80,9 @@ func init() { // NewElement returns a new Element from a uint64 value // // it is equivalent to -// var v Element -// v.SetUint64(...) +// +// var v Element +// v.SetUint64(...) func NewElement(v uint64) Element { z := Element{v} z.Mul(&z, &rSquare) @@ -98,7 +93,7 @@ func NewElement(v uint64) Element { func (z *Element) SetUint64(v uint64) *Element { // sets z LSB to v (non-Montgomery form) and convert z to Montgomery form *z = Element{v} - return z.Mul(z, &rSquare) // z.ToMont() + return z.Mul(z, &rSquare) // z.toMont() } // SetInt64 sets z to v and returns z @@ -125,14 +120,15 @@ func (z *Element) Set(x *Element) *Element { // SetInterface converts provided interface into Element // returns an error if provided type is not supported // supported types: -// Element -// *Element -// uint64 -// int -// string (see SetString for valid formats) -// *big.Int -// big.Int -// []byte +// +// Element +// *Element +// uint64 +// int +// string (see SetString for valid formats) +// *big.Int +// big.Int +// []byte func (z *Element) SetInterface(i1 interface{}) (*Element, error) { if i1 == nil { return nil, errors.New("can't set goldilocks.Element with ") @@ -240,9 +236,7 @@ func (z *Element) IsUint64() bool { // Uint64 returns the uint64 representation of x. If x cannot be represented in a uint64, the result is undefined. func (z *Element) Uint64() uint64 { - zz := *z - zz.FromMont() - return zz[0] + return z.Bits()[0] } // FitsOnOneWord reports whether z words (except the least significant word) are 0 @@ -254,15 +248,12 @@ func (z *Element) FitsOnOneWord() bool { // Cmp compares (lexicographic order) z and x and returns: // -// -1 if z < x -// 0 if z == x -// +1 if z > x -// +// -1 if z < x +// 0 if z == x +// +1 if z > x func (z *Element) Cmp(x *Element) int { - _z := *z - _x := *x - _z.FromMont() - _x.FromMont() + _z := z.Bits() + _x := x.Bits() if _z[0] > _x[0] { return 1 } else if _z[0] < _x[0] { @@ -278,8 +269,7 @@ func (z *Element) LexicographicallyLargest() bool { // we check if the element is larger than (q-1) / 2 // if z - (((q -1) / 2) + 1) have no underflow, then z > (q-1) / 2 - _z := *z - _z.FromMont() + _z := z.Bits() var b uint64 _, b = bits.Sub64(_z[0], 9223372034707292161, 0) @@ -357,81 +347,16 @@ func (z *Element) Halve() { z[0] >>= 1 if carry != 0 { - // when we added q, the result was larger than our avaible limbs + // when we added q, the result was larger than our available limbs // when we shift right, we need to set the highest bit z[0] |= (1 << 63) } } -// Mul z = x * y (mod q) -func (z *Element) Mul(x, y *Element) *Element { - // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis - // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf - // - // The algorithm: - // - // for i=0 to N-1 - // C := 0 - // for j=0 to N-1 - // (C,t[j]) := t[j] + x[j]*y[i] + C - // (t[N+1],t[N]) := t[N] + C - // - // C := 0 - // m := t[0]*q'[0] mod D - // (C,_) := t[0] + m*q[0] - // for j=1 to N-1 - // (C,t[j-1]) := t[j] + m*q[j] + C - // - // (C,t[N-1]) := t[N] + C - // t[N] := t[N+1] + C - // - // → N is the number of machine words needed to store the modulus q - // → D is the word size. For example, on a 64-bit architecture D is 2 64 - // → x[i], y[i], q[i] is the ith word of the numbers x,y,q - // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. - // → t is a temporary array of size N+2 - // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number - - var r uint64 - hi, lo := bits.Mul64(x[0], y[0]) - m := lo * qInvNeg - hi2, lo2 := bits.Mul64(m, q) - _, carry := bits.Add64(lo2, lo, 0) - r, carry = bits.Add64(hi2, hi, carry) - - if carry != 0 || r >= q { - // we need to reduce - r -= q - } - z[0] = r - - return z -} - -// Square z = x * x (mod q) -func (z *Element) Square(x *Element) *Element { - // see Mul for algorithm documentation - - var r uint64 - hi, lo := bits.Mul64(x[0], x[0]) - m := lo * qInvNeg - hi2, lo2 := bits.Mul64(m, q) - _, carry := bits.Add64(lo2, lo, 0) - r, carry = bits.Add64(hi2, hi, carry) - - if carry != 0 || r >= q { - // we need to reduce - r -= q - } - z[0] = r - - return z -} - -// FromMont converts z in place (i.e. mutates) from Montgomery to regular representation +// fromMont converts z in place (i.e. mutates) from Montgomery to regular representation // sets and returns z = z * 1 -func (z *Element) FromMont() *Element { +func (z *Element) fromMont() *Element { fromMont(z) return z } @@ -490,22 +415,71 @@ func (z *Element) Select(c int, x0 *Element, x1 *Element) *Element { return z } +// _mulGeneric is unoptimized textbook CIOS +// it is a fallback solution on x86 when ADX instruction set is not available +// and is used for testing purposes. func _mulGeneric(z, x, y *Element) { - // see Mul for algorithm documentation - var r uint64 - hi, lo := bits.Mul64(x[0], y[0]) - m := lo * qInvNeg - hi2, lo2 := bits.Mul64(m, q) - _, carry := bits.Add64(lo2, lo, 0) - r, carry = bits.Add64(hi2, hi, carry) + // Implements CIOS multiplication -- section 2.3.2 of Tolga Acar's thesis + // https://www.microsoft.com/en-us/research/wp-content/uploads/1998/06/97Acar.pdf + // + // The algorithm: + // + // for i=0 to N-1 + // C := 0 + // for j=0 to N-1 + // (C,t[j]) := t[j] + x[j]*y[i] + C + // (t[N+1],t[N]) := t[N] + C + // + // C := 0 + // m := t[0]*q'[0] mod D + // (C,_) := t[0] + m*q[0] + // for j=1 to N-1 + // (C,t[j-1]) := t[j] + m*q[j] + C + // + // (C,t[N-1]) := t[N] + C + // t[N] := t[N+1] + C + // + // → N is the number of machine words needed to store the modulus q + // → D is the word size. For example, on a 64-bit architecture D is 2 64 + // → x[i], y[i], q[i] is the ith word of the numbers x,y,q + // → q'[0] is the lowest word of the number -q⁻¹ mod r. This quantity is pre-computed, as it does not depend on the inputs. + // → t is a temporary array of size N+2 + // → C, S are machine words. A pair (C,S) refers to (hi-bits, lo-bits) of a two-word number + + var t [2]uint64 + var D uint64 + var m, C uint64 + // ----------------------------------- + // First loop + + C, t[0] = bits.Mul64(y[0], x[0]) + + t[1], D = bits.Add64(t[1], C, 0) + + // m = t[0]n'[0] mod W + m = t[0] * qInvNeg + + // ----------------------------------- + // Second loop + C = madd0(m, q0, t[0]) + + t[0], C = bits.Add64(t[1], C, 0) + t[1], _ = bits.Add64(0, D, C) - if carry != 0 || r >= q { - // we need to reduce - r -= q + if t[1] != 0 { + // we need to reduce, we have a result on 2 words + z[0], _ = bits.Sub64(t[0], q0, 0) + return } - z[0] = r + // copy t into z + z[0] = t[0] + + // if z ⩾ q → z -= q + if !z.smallerThanModulus() { + z[0] -= q + } } func _fromMontGeneric(z *Element) { @@ -519,7 +493,7 @@ func _fromMontGeneric(z *Element) { z[0] = C } - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { z[0] -= q } @@ -527,7 +501,7 @@ func _fromMontGeneric(z *Element) { func _reduceGeneric(z *Element) { - // if z >= q → z -= q + // if z ⩾ q → z -= q if !z.smallerThanModulus() { z[0] -= q } @@ -578,6 +552,35 @@ func (z *Element) BitLen() int { return bits.Len64(z[0]) } +// Hash msg to count prime field elements. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 +func Hash(msg, dst []byte, count int) ([]Element, error) { + // 128 bits of security + // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 + const Bytes = 1 + (Bits-1)/8 + const L = 16 + Bytes + + lenInBytes := count * L + pseudoRandomBytes, err := field.ExpandMsgXmd(msg, dst, lenInBytes) + if err != nil { + return nil, err + } + + // get temporary big int from the pool + vv := field.BigIntPool.Get() + + res := make([]Element, count) + for i := 0; i < count; i++ { + vv.SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) + res[i].SetBigInt(vv) + } + + // release object into pool + field.BigIntPool.Put(vv) + + return res, nil +} + // Exp z = xᵏ (mod q) func (z *Element) Exp(x Element, k *big.Int) *Element { if k.IsUint64() && k.Uint64() == 0 { @@ -592,8 +595,8 @@ func (z *Element) Exp(x Element, k *big.Int) *Element { // we negate k in a temp big.Int since // Int.Bit(_) of k and -k is different - e = bigIntPool.Get().(*big.Int) - defer bigIntPool.Put(e) + e = field.BigIntPool.Get() + defer field.BigIntPool.Put(e) e.Neg(k) } @@ -616,23 +619,26 @@ var rSquare = Element{ 18446744065119617025, } -// ToMont converts z to Montgomery form +// toMont converts z to Montgomery form // sets and returns z = z * r² -func (z *Element) ToMont() *Element { +func (z *Element) toMont() *Element { return z.Mul(z, &rSquare) } -// ToRegular returns z in regular form (doesn't mutate z) -func (z Element) ToRegular() Element { - return *z.FromMont() -} - // String returns the decimal representation of z as generated by // z.Text(10). func (z *Element) String() string { return z.Text(10) } +// toBigInt returns z as a big.Int in Montgomery form +func (z *Element) toBigInt(res *big.Int) *big.Int { + var b [Bytes]byte + binary.BigEndian.PutUint64(b[0:8], z[0]) + + return res.SetBytes(b[:]) +} + // Text returns the string representation of z in the given base. // Base must be between 2 and 36, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35. @@ -651,35 +657,42 @@ func (z *Element) Text(base int) string { if base == 10 { var zzNeg Element zzNeg.Neg(z) - zzNeg.FromMont() + zzNeg.fromMont() if zzNeg[0] <= maxUint16 && zzNeg[0] != 0 { return "-" + strconv.FormatUint(zzNeg[0], base) } } - zz := *z - zz.FromMont() + zz := z.Bits() return strconv.FormatUint(zz[0], base) } -// ToBigInt returns z as a big.Int in Montgomery form -func (z *Element) ToBigInt(res *big.Int) *big.Int { - var b [Limbs * 8]byte - binary.BigEndian.PutUint64(b[0:8], z[0]) - - return res.SetBytes(b[:]) +// BigInt sets and return z as a *big.Int +func (z *Element) BigInt(res *big.Int) *big.Int { + _z := *z + _z.fromMont() + return _z.toBigInt(res) } // ToBigIntRegular returns z as a big.Int in regular form +// +// Deprecated: use BigInt(*big.Int) instead func (z Element) ToBigIntRegular(res *big.Int) *big.Int { - z.FromMont() - return z.ToBigInt(res) + z.fromMont() + return z.toBigInt(res) } -// Bytes returns the value of z as a big-endian byte array -func (z *Element) Bytes() (res [Limbs * 8]byte) { - _z := z.ToRegular() - binary.BigEndian.PutUint64(res[0:8], _z[0]) +// Bits provides access to z by returning its value as a little-endian [1]uint64 array. +// Bits is intended to support implementation of missing low-level Element +// functionality outside this package; it should be avoided otherwise. +func (z *Element) Bits() [1]uint64 { + _z := *z + fromMont(&_z) + return _z +} +// Bytes returns the value of z as a big-endian byte array +func (z *Element) Bytes() (res [Bytes]byte) { + BigEndian.PutElement(&res, *z) return } @@ -692,24 +705,44 @@ func (z *Element) Marshal() []byte { // SetBytes interprets e as the bytes of a big-endian unsigned integer, // sets z to that value, and returns z. func (z *Element) SetBytes(e []byte) *Element { - if len(e) == 8 { + if len(e) == Bytes { // fast path - z[0] = binary.BigEndian.Uint64(e) - return z.ToMont() + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err == nil { + *z = v + return z + } } + + // slow path. // get a big int from our pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() vv.SetBytes(e) // set big int z.SetBigInt(vv) // put temporary object back in pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } +// SetBytesCanonical interprets e as the bytes of a big-endian 8-byte integer. +// If e is not a 8-byte slice or encodes a value higher than q, +// SetBytesCanonical returns an error. +func (z *Element) SetBytesCanonical(e []byte) error { + if len(e) != Bytes { + return errors.New("invalid goldilocks.Element encoding") + } + v, err := BigEndian.Element((*[Bytes]byte)(e)) + if err != nil { + return err + } + *z = v + return nil +} + // SetBigInt sets z to v and returns z func (z *Element) SetBigInt(v *big.Int) *Element { z.SetZero() @@ -727,17 +760,16 @@ func (z *Element) SetBigInt(v *big.Int) *Element { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() // copy input + modular reduction - vv.Set(v) vv.Mod(v, &_modulus) // set big int byte value z.setBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z } @@ -759,20 +791,20 @@ func (z *Element) setBigInt(v *big.Int) *Element { } } - return z.ToMont() + return z.toMont() } // SetString creates a big.Int with number and calls SetBigInt on z // // The number prefix determines the actual base: A prefix of -// ''0b'' or ''0B'' selects base 2, ''0'', ''0o'' or ''0O'' selects base 8, -// and ''0x'' or ''0X'' selects base 16. Otherwise, the selected base is 10 +// ”0b” or ”0B” selects base 2, ”0”, ”0o” or ”0O” selects base 8, +// and ”0x” or ”0X” selects base 16. Otherwise, the selected base is 10 // and no prefix is accepted. // // For base 16, lower and upper case letters are considered the same: // The letters 'a' to 'f' and 'A' to 'F' represent digit values 10 to 15. // -// An underscore character ''_'' may appear between a base +// An underscore character ”_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number. // Incorrect placement of underscores is reported as a panic if there @@ -781,7 +813,7 @@ func (z *Element) setBigInt(v *big.Int) *Element { // If the number is invalid this method leaves z unchanged and returns nil, error. func (z *Element) SetString(number string) (*Element, error) { // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(number, 0); !ok { return nil, errors.New("Element.SetString failed -> can't parse number into a big.Int " + number) @@ -790,7 +822,7 @@ func (z *Element) SetString(number string) (*Element, error) { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return z, nil } @@ -830,7 +862,7 @@ func (z *Element) UnmarshalJSON(data []byte) error { } // get temporary big int from the pool - vv := bigIntPool.Get().(*big.Int) + vv := field.BigIntPool.Get() if _, ok := vv.SetString(s, 0); !ok { return errors.New("can't parse into a big.Int: " + s) @@ -839,10 +871,67 @@ func (z *Element) UnmarshalJSON(data []byte) error { z.SetBigInt(vv) // release object into pool - bigIntPool.Put(vv) + field.BigIntPool.Put(vv) return nil } +// A ByteOrder specifies how to convert byte slices into a Element +type ByteOrder interface { + Element(*[Bytes]byte) (Element, error) + PutElement(*[Bytes]byte, Element) + String() string +} + +// BigEndian is the big-endian implementation of ByteOrder and AppendByteOrder. +var BigEndian bigEndian + +type bigEndian struct{} + +// Element interpret b is a big-endian 8-byte slice. +// If b encodes a value higher than q, Element returns error. +func (bigEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.BigEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid goldilocks.Element encoding") + } + + z.toMont() + return z, nil +} + +func (bigEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.BigEndian.PutUint64((*b)[0:8], e[0]) +} + +func (bigEndian) String() string { return "BigEndian" } + +// LittleEndian is the little-endian implementation of ByteOrder and AppendByteOrder. +var LittleEndian littleEndian + +type littleEndian struct{} + +func (littleEndian) Element(b *[Bytes]byte) (Element, error) { + var z Element + z[0] = binary.LittleEndian.Uint64((*b)[0:8]) + + if !z.smallerThanModulus() { + return Element{}, errors.New("invalid goldilocks.Element encoding") + } + + z.toMont() + return z, nil +} + +func (littleEndian) PutElement(b *[Bytes]byte, e Element) { + e.fromMont() + binary.LittleEndian.PutUint64((*b)[0:8], e[0]) +} + +func (littleEndian) String() string { return "LittleEndian" } + // Legendre returns the Legendre symbol of z (either +1, -1, or 0.) func (z *Element) Legendre() int { var l Element @@ -875,7 +964,7 @@ func (z *Element) Sqrt(x *Element) *Element { // y = x^((s+1)/2)) = w * x y.Mul(x, &w) - // b = x^s = w * w * x = y * x + // b = xˢ = w * w * x = y * x b.Mul(&w, &y) // g = nonResidue ^ s @@ -885,7 +974,7 @@ func (z *Element) Sqrt(x *Element) *Element { r := uint64(32) // compute legendre symbol - // t = x^((q-1)/2) = r-1 squaring of x^s + // t = x^((q-1)/2) = r-1 squaring of xˢ t = b for i := uint64(0); i < r-1; i++ { t.Square(&t) @@ -938,7 +1027,7 @@ func (z *Element) Inverse(x *Element) *Element { var r, s, u, v uint64 u = q - s = 18446744065119617025 // s = r^2 + s = 18446744065119617025 // s = r² r = 0 v = x[0] diff --git a/field/goldilocks/element_ops_purego.go b/field/goldilocks/element_ops_purego.go new file mode 100644 index 000000000..5dc51ea24 --- /dev/null +++ b/field/goldilocks/element_ops_purego.go @@ -0,0 +1,123 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package goldilocks + +import "math/bits" + +// MulBy3 x *= 3 (mod q) +func MulBy3(x *Element) { + var y Element + y.SetUint64(3) + x.Mul(x, &y) +} + +// MulBy5 x *= 5 (mod q) +func MulBy5(x *Element) { + var y Element + y.SetUint64(5) + x.Mul(x, &y) +} + +// MulBy13 x *= 13 (mod q) +func MulBy13(x *Element) { + var y Element + y.SetUint64(13) + x.Mul(x, &y) +} + +// Butterfly sets +// +// a = a + b (mod q) +// b = a - b (mod q) +func Butterfly(a, b *Element) { + _butterflyGeneric(a, b) +} + +func fromMont(z *Element) { + _fromMontGeneric(z) +} + +func reduce(z *Element) { + _reduceGeneric(z) +} + +// Mul z = x * y (mod q) +func (z *Element) Mul(x, y *Element) *Element { + + // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): + // hi, lo := x * y + // m := (lo * qInvNeg) mod R + // (*) r := (hi * R + lo + m * q) / R + // reduce r if necessary + + // On the emphasized line, we get r = hi + (lo + m * q) / R + // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 + // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. + // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) + // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. + + var r uint64 + hi, lo := bits.Mul64(x[0], y[0]) + if lo != 0 { + hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow + } + m := lo * qInvNeg + hi2, _ := bits.Mul64(m, q) + r, carry := bits.Add64(hi2, hi, 0) + + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r + + return z +} + +// Square z = x * x (mod q) +func (z *Element) Square(x *Element) *Element { + // see Mul for algorithm documentation + + // In fact, since the modulus R fits on one register, the CIOS algorithm gets reduced to standard REDC (textbook Montgomery reduction): + // hi, lo := x * y + // m := (lo * qInvNeg) mod R + // (*) r := (hi * R + lo + m * q) / R + // reduce r if necessary + + // On the emphasized line, we get r = hi + (lo + m * q) / R + // If we write hi2, lo2 = m * q then R | m * q - lo2 ⇒ R | (lo * qInvNeg) q - lo2 = -lo - lo2 + // This shows lo + lo2 = 0 mod R. i.e. lo + lo2 = 0 if lo = 0 and R otherwise. + // Which finally gives (lo + m * q) / R = (lo + lo2 + R hi2) / R = hi2 + (lo+lo2) / R = hi2 + (lo != 0) + // This "optimization" lets us do away with one MUL instruction on ARM architectures and is available for all q < R. + + var r uint64 + hi, lo := bits.Mul64(x[0], x[0]) + if lo != 0 { + hi++ // x[0] * y[0] ≤ 2¹²⁸ - 2⁶⁵ + 1, meaning hi ≤ 2⁶⁴ - 2 so no need to worry about overflow + } + m := lo * qInvNeg + hi2, _ := bits.Mul64(m, q) + r, carry := bits.Add64(hi2, hi, 0) + + if carry != 0 || r >= q { + // we need to reduce + r -= q + } + z[0] = r + + return z +} diff --git a/field/goldilocks/element_test.go b/field/goldilocks/element_test.go index 790676788..e3cacf21b 100644 --- a/field/goldilocks/element_test.go +++ b/field/goldilocks/element_test.go @@ -179,17 +179,10 @@ func BenchmarkElementFromMont(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() for i := 0; i < b.N; i++ { - benchResElement.FromMont() + benchResElement.fromMont() } } -func BenchmarkElementToMont(b *testing.B) { - benchResElement.SetRandom() - b.ResetTimer() - for i := 0; i < b.N; i++ { - benchResElement.ToMont() - } -} func BenchmarkElementSquare(b *testing.B) { benchResElement.SetRandom() b.ResetTimer() @@ -582,7 +575,7 @@ func TestElementBitLen(t *testing.T) { properties.Property("BitLen should output same result than big.Int.BitLen", prop.ForAll( func(a testPairElement) bool { - return a.element.FromMont().BitLen() == a.bigint.BitLen() + return a.element.fromMont().BitLen() == a.bigint.BitLen() }, genA, )) @@ -697,7 +690,7 @@ func TestElementAdd(t *testing.T) { var d, e big.Int d.Add(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -710,13 +703,13 @@ func TestElementAdd(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Add(&a.element, &r) d.Add(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -745,17 +738,17 @@ func TestElementAdd(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Add(&a, &b) d.Add(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Add failed special test values") } } @@ -806,7 +799,7 @@ func TestElementSub(t *testing.T) { var d, e big.Int d.Sub(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -819,13 +812,13 @@ func TestElementSub(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Sub(&a.element, &r) d.Sub(&a.bigint, &rb).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -854,17 +847,17 @@ func TestElementSub(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Sub(&a, &b) d.Sub(&aBig, &bBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sub failed special test values") } } @@ -915,7 +908,7 @@ func TestElementMul(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &b.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -928,7 +921,7 @@ func TestElementMul(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Mul(&a.element, &r) @@ -942,7 +935,7 @@ func TestElementMul(t *testing.T) { return false } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -982,11 +975,11 @@ func TestElementMul(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Mul(&a, &b) @@ -999,7 +992,7 @@ func TestElementMul(t *testing.T) { t.Fatal("Mul failed special test values: asm and generic impl don't match") } - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Mul failed special test values") } } @@ -1051,7 +1044,7 @@ func TestElementDiv(t *testing.T) { d.ModInverse(&b.bigint, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1064,14 +1057,14 @@ func TestElementDiv(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Div(&a.element, &r) d.ModInverse(&rb, Modulus()) d.Mul(&d, &a.bigint).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1100,18 +1093,18 @@ func TestElementDiv(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Div(&a, &b) d.ModInverse(&bBig, Modulus()) d.Mul(&d, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Div failed special test values") } } @@ -1162,7 +1155,7 @@ func TestElementExp(t *testing.T) { var d, e big.Int d.Exp(&a.bigint, &b.bigint, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1175,13 +1168,13 @@ func TestElementExp(t *testing.T) { for _, r := range testValues { var d, e, rb big.Int - r.ToBigIntRegular(&rb) + r.BigInt(&rb) var c Element c.Exp(a.element, &rb) d.Exp(&a.bigint, &rb, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { return false } } @@ -1210,17 +1203,17 @@ func TestElementExp(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) for _, b := range testValues { var bBig, d, e big.Int - b.ToBigIntRegular(&bBig) + b.BigInt(&bBig) var c Element c.Exp(a, &bBig) d.Exp(&aBig, &bBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Exp failed special test values") } } @@ -1265,7 +1258,7 @@ func TestElementSquare(t *testing.T) { var d, e big.Int d.Mul(&a.bigint, &a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1286,14 +1279,14 @@ func TestElementSquare(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Square(&a) var d, e big.Int d.Mul(&aBig, &aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Square failed special test values") } } @@ -1337,7 +1330,7 @@ func TestElementInverse(t *testing.T) { var d, e big.Int d.ModInverse(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1358,14 +1351,14 @@ func TestElementInverse(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Inverse(&a) var d, e big.Int d.ModInverse(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Inverse failed special test values") } } @@ -1409,7 +1402,7 @@ func TestElementSqrt(t *testing.T) { var d, e big.Int d.ModSqrt(&a.bigint, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1430,14 +1423,14 @@ func TestElementSqrt(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Sqrt(&a) var d, e big.Int d.ModSqrt(&aBig, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Sqrt failed special test values") } } @@ -1481,7 +1474,7 @@ func TestElementDouble(t *testing.T) { var d, e big.Int d.Lsh(&a.bigint, 1).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1502,14 +1495,14 @@ func TestElementDouble(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Double(&a) var d, e big.Int d.Lsh(&aBig, 1).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Double failed special test values") } } @@ -1553,7 +1546,7 @@ func TestElementNeg(t *testing.T) { var d, e big.Int d.Neg(&a.bigint).Mod(&d, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, )) @@ -1574,14 +1567,14 @@ func TestElementNeg(t *testing.T) { for _, a := range testValues { var aBig big.Int - a.ToBigIntRegular(&aBig) + a.BigInt(&aBig) var c Element c.Neg(&a) var d, e big.Int d.Neg(&aBig).Mod(&d, Modulus()) - if c.FromMont().ToBigInt(&e).Cmp(&d) != 0 { + if c.BigInt(&e).Cmp(&d) != 0 { t.Fatal("Neg failed special test values") } } @@ -1969,7 +1962,7 @@ func TestElementNegativeExp(t *testing.T) { d.Exp(&a.bigint, &nb, Modulus()) - return c.FromMont().ToBigInt(&e).Cmp(&d) == 0 + return c.BigInt(&e).Cmp(&d) == 0 }, genA, genA, )) @@ -2102,17 +2095,17 @@ func TestElementFromMont(t *testing.T) { func(a testPairElement) bool { c := a.element d := a.element - c.FromMont() + c.fromMont() _fromMontGeneric(&d) return c.Equal(&d) }, genA, )) - properties.Property("x.FromMont().ToMont() == x", prop.ForAll( + properties.Property("x.fromMont().toMont() == x", prop.ForAll( func(a testPairElement) bool { c := a.element - c.FromMont().ToMont() + c.fromMont().toMont() return c.Equal(&a.element) }, genA, @@ -2198,7 +2191,7 @@ func gen() gopter.Gen { } } - g.element.ToBigIntRegular(&g.bigint) + g.element.BigInt(&g.bigint) genResult := gopter.NewGenResult(g, gopter.NoShrinker) return genResult } diff --git a/field/goldilocks/internal/main.go b/field/goldilocks/internal/main.go index a6ae38c4d..4f5bacd3f 100644 --- a/field/goldilocks/internal/main.go +++ b/field/goldilocks/internal/main.go @@ -3,14 +3,14 @@ package main import ( "fmt" - "github.com/consensys/gnark-crypto/internal/field" - "github.com/consensys/gnark-crypto/internal/field/generator" + "github.com/consensys/gnark-crypto/field/generator" + "github.com/consensys/gnark-crypto/field/generator/config" ) //go:generate go run main.go func main() { const modulus = "0xFFFFFFFF00000001" - goldilocks, err := field.NewFieldConfig("goldilocks", "Element", modulus, true) + goldilocks, err := config.NewFieldConfig("goldilocks", "Element", modulus, true) if err != nil { panic(err) } diff --git a/field/hashutils.go b/field/hashutils.go new file mode 100644 index 000000000..4b25f861c --- /dev/null +++ b/field/hashutils.go @@ -0,0 +1,94 @@ +package field + +import ( + "crypto/sha256" + "errors" +) + +// ExpandMsgXmd expands msg to a slice of lenInBytes bytes. +// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5 +// https://tools.ietf.org/html/rfc8017#section-4.1 (I2OSP/O2ISP) +func ExpandMsgXmd(msg, dst []byte, lenInBytes int) ([]byte, error) { + + h := sha256.New() + ell := (lenInBytes + h.Size() - 1) / h.Size() // ceil(len_in_bytes / b_in_bytes) + if ell > 255 { + return nil, errors.New("invalid lenInBytes") + } + if len(dst) > 255 { + return nil, errors.New("invalid domain size (>255 bytes)") + } + sizeDomain := uint8(len(dst)) + + // Z_pad = I2OSP(0, r_in_bytes) + // l_i_b_str = I2OSP(len_in_bytes, 2) + // DST_prime = I2OSP(len(DST), 1) ∥ DST + // b₀ = H(Z_pad ∥ msg ∥ l_i_b_str ∥ I2OSP(0, 1) ∥ DST_prime) + h.Reset() + if _, err := h.Write(make([]byte, h.BlockSize())); err != nil { + return nil, err + } + if _, err := h.Write(msg); err != nil { + return nil, err + } + if _, err := h.Write([]byte{uint8(lenInBytes >> 8), uint8(lenInBytes), uint8(0)}); err != nil { + return nil, err + } + if _, err := h.Write(dst); err != nil { + return nil, err + } + if _, err := h.Write([]byte{sizeDomain}); err != nil { + return nil, err + } + b0 := h.Sum(nil) + + // b₁ = H(b₀ ∥ I2OSP(1, 1) ∥ DST_prime) + h.Reset() + if _, err := h.Write(b0); err != nil { + return nil, err + } + if _, err := h.Write([]byte{uint8(1)}); err != nil { + return nil, err + } + if _, err := h.Write(dst); err != nil { + return nil, err + } + if _, err := h.Write([]byte{sizeDomain}); err != nil { + return nil, err + } + b1 := h.Sum(nil) + + res := make([]byte, lenInBytes) + copy(res[:h.Size()], b1) + + for i := 2; i <= ell; i++ { + // b_i = H(strxor(b₀, b_(i - 1)) ∥ I2OSP(i, 1) ∥ DST_prime) + h.Reset() + strxor := make([]byte, h.Size()) + for j := 0; j < h.Size(); j++ { + strxor[j] = b0[j] ^ b1[j] + } + if _, err := h.Write(strxor); err != nil { + return nil, err + } + if _, err := h.Write([]byte{uint8(i)}); err != nil { + return nil, err + } + if _, err := h.Write(dst); err != nil { + return nil, err + } + if _, err := h.Write([]byte{sizeDomain}); err != nil { + return nil, err + } + b1 = h.Sum(nil) + copy(res[h.Size()*(i-1):min(h.Size()*i, len(res))], b1) + } + return res, nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/field/hashutils_test.go b/field/hashutils_test.go new file mode 100644 index 000000000..3ff1d3da8 --- /dev/null +++ b/field/hashutils_test.go @@ -0,0 +1,132 @@ +package field + +import ( + "bytes" + "encoding/hex" + "testing" +) + +type expandMsgXmdTestCase struct { + msg string + lenInBytes int + uniformBytesHex string +} + +// Test vectors from https://datatracker.ietf.org/doc/draft-irtf-cfrg-hash-to-curve/14/ Page 148 Section K.1. +func TestExpandMsgXmd(t *testing.T) { + //name := "expand_message_xmd" + dst := "QUUX-V01-CS02-with-expander-SHA256-128" + //hash := "SHA256" + //k := 128 + + testCases := []expandMsgXmdTestCase{ + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + + { + "abc", + 0x20, + "d8ccab23b5985ccea865c6c97b6e5b8350e794e603b4b97902f53a8a0d605615", + }, + + { + "abcdef0123456789", + 0x20, + "eff31487c770a893cfb36f912fbfcbff40d5661771ca4b2cb4eafe524333f5c1", + }, + + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x20, + "b23a1d2b4d97b2ef7785562a7e8bac7eed54ed6e97e29aa51bfe3f12ddad1ff9", + }, + + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x20, + "4623227bcc01293b8c130bf771da8c298dede7383243dc0993d2d94823958c4c", + }, + { + "", + 0x80, + "af84c27ccfd45d41914fdff5df25293e221afc53d8ad2ac06d5e3e29485dadbee0d121587713a3e0dd4d5e69e93eb7cd4f5df4cd103e188cf60cb02edc3edf18eda8576c412b18ffb658e3dd6ec849469b979d444cf7b26911a08e63cf31f9dcc541708d3491184472c2c29bb749d4286b004ceb5ee6b9a7fa5b646c993f0ced", + }, + { + "", + 0x20, + "68a985b87eb6b46952128911f2a4412bbc302a9d759667f87f7a21d803f07235", + }, + { + "abc", + 0x80, + "abba86a6129e366fc877aab32fc4ffc70120d8996c88aee2fe4b32d6c7b6437a647e6c3163d40b76a73cf6a5674ef1d890f95b664ee0afa5359a5c4e07985635bbecbac65d747d3d2da7ec2b8221b17b0ca9dc8a1ac1c07ea6a1e60583e2cb00058e77b7b72a298425cd1b941ad4ec65e8afc50303a22c0f99b0509b4c895f40", + }, + { + "abcdef0123456789", + 0x80, + "ef904a29bffc4cf9ee82832451c946ac3c8f8058ae97d8d629831a74c6572bd9ebd0df635cd1f208e2038e760c4994984ce73f0d55ea9f22af83ba4734569d4bc95e18350f740c07eef653cbb9f87910d833751825f0ebefa1abe5420bb52be14cf489b37fe1a72f7de2d10be453b2c9d9eb20c7e3f6edc5a60629178d9478df", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x80, + "80be107d0884f0d881bb460322f0443d38bd222db8bd0b0a5312a6fedb49c1bbd88fd75d8b9a09486c60123dfa1d73c1cc3169761b17476d3c6b7cbbd727acd0e2c942f4dd96ae3da5de368d26b32286e32de7e5a8cb2949f866a0b80c58116b29fa7fabb3ea7d520ee603e0c25bcaf0b9a5e92ec6a1fe4e0391d1cdbce8c68a", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x80, + "546aff5444b5b79aa6148bd81728704c32decb73a3ba76e9e75885cad9def1d06d6792f8a7d12794e90efed817d96920d728896a4510864370c207f99bd4a608ea121700ef01ed879745ee3e4ceef777eda6d9e5e38b90c86ea6fb0b36504ba4a45d22e86f6db5dd43d98a294bebb9125d5b794e9d2a81181066eb954966a487", + }, + //test cases not in the standard + { + "", + 0x30, + "3808e9bb0ade2df3aa6f1b459eb5058a78142f439213ddac0c97dcab92ae5a8408d86b32bbcc87de686182cbdf65901f", + }, + { + "abc", + 0x30, + "2b877f5f0dfd881405426c6b87b39205ef53a548b0e4d567fc007cb37c6fa1f3b19f42871efefca518ac950c27ac4e28", + }, + { + "abcdef0123456789", + 0x30, + "226da1780b06e59723714f80da9a63648aebcfc1f08e0db87b5b4d16b108da118214c1450b0e86f9cefeb44903fd3aba", + }, + { + "q128_qqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqqq", + 0x30, + "12b23ae2e888f442fd6d0d85d90a0d7ed5337d38113e89cdc7c22db91bd0abaec1023e9a8f0ef583a111104e2f8a0637", + }, + { + "a512_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + 0x30, + "1aaee90016547a85ab4dc55e4f78a364c2e239c0e58b05753453c63e6e818334005e90d9ce8f047bddab9fbb315f8722", + }, + } + + for _, testCase := range testCases { + uniformBytes, err := ExpandMsgXmd([]byte(testCase.msg), []byte(dst), testCase.lenInBytes) + if err != nil { + t.Fatal(err) + } + + var testCaseUniformBytes []byte + testCaseUniformBytes, err = hex.DecodeString(testCase.uniformBytesHex) + if err != nil { + t.Fatal(err) + } + + if len(uniformBytes) != testCase.lenInBytes { + t.Error("wrong length: expected", testCase.lenInBytes, "got", len(uniformBytes)) + } + + if !bytes.Equal(uniformBytes, testCaseUniformBytes) { + uniformBytesHex := make([]byte, len(uniformBytes)*2) + hex.Encode(uniformBytesHex, uniformBytes) + t.Errorf("expected \"%s\" got \"%s\"", testCase.uniformBytesHex, uniformBytesHex) + } + } +} diff --git a/field/utils.go b/field/utils.go new file mode 100644 index 000000000..27db0c7c6 --- /dev/null +++ b/field/utils.go @@ -0,0 +1,53 @@ +package field + +import ( + "fmt" + "math/big" + "math/bits" + "sync" +) + +// BigIntPool is a shared *big.Int memory pool +var BigIntPool bigIntPool + +var _bigIntPool = sync.Pool{ + New: func() interface{} { + return new(big.Int) + }, +} + +type bigIntPool struct{} + +func (bigIntPool) Get() *big.Int { + return _bigIntPool.Get().(*big.Int) +} + +func (bigIntPool) Put(v *big.Int) { + _bigIntPool.Put(v) +} + +// BigIntMatchUint64Slice is a test helper to match big.Int words againt a uint64 slice +func BigIntMatchUint64Slice(aInt *big.Int, a []uint64) error { + + words := aInt.Bits() + + const steps = 64 / bits.UintSize + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + for i := 0; i < len(a)*steps; i++ { + + var wI big.Word + + if i < len(words) { + wI = words[i] + } + + aI := a[i/steps] >> ((i * bits.UintSize) % 64) + aI &= filter + + if uint64(wI) != aI { + return fmt.Errorf("bignum mismatch: disagreement on word %d: %x ≠ %x; %d ≠ %d", i, uint64(wI), aI, uint64(wI), aI) + } + } + + return nil +} diff --git a/go.mod b/go.mod index 583ea3142..fe822ccf5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/consensys/gnark-crypto -go 1.17 +go 1.18 require ( github.com/consensys/bavard v0.1.13 @@ -9,7 +9,7 @@ require ( github.com/spf13/cobra v1.5.0 github.com/stretchr/testify v1.8.0 golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa - golang.org/x/sys v0.0.0-20220727055044-e65921a090b8 + golang.org/x/sys v0.2.0 ) require ( diff --git a/go.sum b/go.sum index 24019a221..b3ae5f84f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/consensys/bavard v0.1.12 h1:rApQlUvBg5FeW/fnigtVnAs0sBrgDN2pEuHNdWElSUE= -github.com/consensys/bavard v0.1.12/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -28,15 +26,8 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220727055044-e65921a090b8 h1:dyU22nBWzrmTQxtNrr4dzVOvaw35nUYE279vF9UmsI8= -golang.org/x/sys v0.0.0-20220727055044-e65921a090b8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/internal/bench/main.go b/internal/bench/main.go index b0f881276..3611d10ec 100644 --- a/internal/bench/main.go +++ b/internal/bench/main.go @@ -45,7 +45,7 @@ func main() { for _, e := range entries { buf.Reset() count := strconv.Itoa(benchCount) - cmd := exec.Command("go", "test", "-timeout", "10m", "-run", "^$", "-bench", regexp, "-count", count, "-tags", "amd64_adx") + cmd := exec.Command("go", "test", "-timeout", "10m", "-run", "^$", "-bench", regexp, "-count", count) args := strings.Join(cmd.Args, " ") log.Println("running benchmark", "dir", e.path, "cmd", args) cmd.Dir = e.path diff --git a/internal/field/field.md b/internal/field/field.md deleted file mode 100644 index 4e9648ceb..000000000 --- a/internal/field/field.md +++ /dev/null @@ -1,48 +0,0 @@ - -# Usage - -At the root of your repo: -```bash -go get github.com/consensys/gnark-crypto/field -``` - -then in a `main.go` (that can be called using a `go:generate` workflow): - -``` -generator.GenerateFF(packageName, structName, modulus, destinationPath, false) -``` - -The generated type has an API that's similar with `big.Int` - -Example API signature -```go -// Mul z = x * y mod q -func (z *Element) Mul(x, y *Element) *Element -``` - -and can be used like so: - -```go -var a, b Element -a.SetUint64(2) -b.SetString("984896738") - -a.Mul(a, b) - -a.Sub(a, a) - .Add(a, b) - .Inv(a) - -b.Exp(b, 42) -b.Neg(b) -``` - -### Build tags - -Generates optimized assembly for `amd64` target. - -For the `Mul` operation, using `ADX` instructions and `ADOX/ADCX` result in a significant performance gain. - -The "default" target `amd64` checks if the running architecture supports these instruction, and reverts to generic path if not. This check adds a branch and forces the function to reserve some bytes on the frame to store the argument to call `_mulGeneric` . - -This package outputs code that can be compiled with `amd64_adx` flag which omits this check. Will crash if the platform running the binary doesn't support the `ADX` instructions (roughly, before 2016). \ No newline at end of file diff --git a/internal/field/internal/templates/element/mul_cios.go b/internal/field/internal/templates/element/mul_cios.go deleted file mode 100644 index ffd7f55e5..000000000 --- a/internal/field/internal/templates/element/mul_cios.go +++ /dev/null @@ -1,77 +0,0 @@ -package element - -const MulCIOS = ` -{{ define "mul_cios" }} - var t [{{add .all.NbWords 1}}]uint64 - var D uint64 - var m, C uint64 - - {{- range $j := .all.NbWordsIndexesFull}} - // ----------------------------------- - // First loop - {{ if eq $j 0}} - C, t[0] = bits.Mul64({{$.V2}}[{{$j}}], {{$.V1}}[0]) - {{- range $i := $.all.NbWordsIndexesNoZero}} - C, t[{{$i}}] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], C) - {{- end}} - {{ else }} - C, t[0] = madd1({{$.V2}}[{{$j}}], {{$.V1}}[0], t[0]) - {{- range $i := $.all.NbWordsIndexesNoZero}} - C, t[{{$i}}] = madd2({{$.V2}}[{{$j}}], {{$.V1}}[{{$i}}], t[{{$i}}], C) - {{- end}} - {{ end }} - t[{{$.all.NbWords}}], D = bits.Add64(t[{{$.all.NbWords}}], C, 0) - - // m = t[0]n'[0] mod W - m = t[0] * qInvNeg - - // ----------------------------------- - // Second loop - C = madd0(m, q0, t[0]) - {{- range $i := $.all.NbWordsIndexesNoZero}} - C, t[{{sub $i 1}}] = madd2(m, q{{$i}}, t[{{$i}}], C) - {{- end}} - - t[{{sub $.all.NbWords 1}}], C = bits.Add64(t[{{$.all.NbWords}}], C, 0) - t[{{$.all.NbWords}}], _ = bits.Add64(0, D, C) - {{- end}} - - - if t[{{$.all.NbWords}}] != 0 { - // we need to reduce, we have a result on {{add 1 $.all.NbWords}} words - {{- if gt $.all.NbWords 1}} - var b uint64 - {{- end}} - z[0], {{- if gt $.all.NbWords 1}}b{{- else}}_{{- end}} = bits.Sub64(t[0], q0, 0) - {{- range $i := .all.NbWordsIndexesNoZero}} - {{- if eq $i $.all.NbWordsLastIndex}} - z[{{$i}}], _ = bits.Sub64(t[{{$i}}], q{{$i}}, b) - {{- else }} - z[{{$i}}], b = bits.Sub64(t[{{$i}}], q{{$i}}, b) - {{- end}} - {{- end}} - return - } - - // copy t into z - {{- range $i := $.all.NbWordsIndexesFull}} - z[{{$i}}] = t[{{$i}}] - {{- end}} - -{{ end }} - -{{ define "mul_cios_one_limb" }} - var r uint64 - hi, lo := bits.Mul64({{$.V1}}[0], {{$.V2}}[0]) - m := lo * qInvNeg - hi2, lo2 := bits.Mul64(m, q) - _, carry := bits.Add64(lo2, lo, 0) - r, carry = bits.Add64(hi2, hi, carry) - - if carry != 0 || r >= q { - // we need to reduce - r -= q - } - z[0] = r -{{ end }} -` diff --git a/internal/field/internal/templates/element/mul_nocarry.go b/internal/field/internal/templates/element/mul_nocarry.go deleted file mode 100644 index c283c3683..000000000 --- a/internal/field/internal/templates/element/mul_nocarry.go +++ /dev/null @@ -1,57 +0,0 @@ -package element - -// MulNoCarry see https://hackmd.io/@gnark/modular_multiplication for more info on the algorithm -const MulNoCarry = ` -{{ define "mul_nocarry" }} -var t [{{.all.NbWords}}]uint64 -var c [3]uint64 -{{- range $j := .all.NbWordsIndexesFull}} -{ - // round {{$j}} - v := {{$.V1}}[{{$j}}] - {{- if eq $j $.all.NbWordsLastIndex}} - c[1], c[0] = madd1(v, {{$.V2}}[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - {{- if eq $.all.NbWords 1}} - z[0], _ = madd3(m, q0, c[0], c[2], c[1]) - {{- else}} - {{- range $i := $.all.NbWordsIndexesNoZero}} - c[1], c[0] = madd2(v, {{$.V2}}[{{$i}}], c[1], t[{{$i}}]) - {{- if eq $i $.all.NbWordsLastIndex}} - z[{{sub $.all.NbWords 1}}], z[{{sub $i 1}}] = madd3(m, q{{$i}}, c[0], c[2], c[1]) - {{- else}} - c[2], z[{{sub $i 1}}] = madd2(m, q{{$i}}, c[2], c[0]) - {{- end}} - {{- end}} - {{- end}} - {{- else if eq $j 0}} - c[1], c[0] = bits.Mul64(v, {{$.V2}}[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - {{- range $i := $.all.NbWordsIndexesNoZero}} - c[1], c[0] = madd1(v, {{$.V2}}[{{$i}}], c[1]) - {{- if eq $i $.all.NbWordsLastIndex}} - t[{{sub $.all.NbWords 1}}], t[{{sub $i 1}}] = madd3(m, q{{$i}}, c[0], c[2], c[1]) - {{- else}} - c[2], t[{{sub $i 1}}] = madd2(m, q{{$i}}, c[2], c[0]) - {{- end}} - {{- end}} - {{- else}} - c[1], c[0] = madd1(v, {{$.V2}}[0], t[0]) - m := c[0] * qInvNeg - c[2] = madd0(m, q0, c[0]) - {{- range $i := $.all.NbWordsIndexesNoZero}} - c[1], c[0] = madd2(v, {{$.V2}}[{{$i}}], c[1], t[{{$i}}]) - {{- if eq $i $.all.NbWordsLastIndex}} - t[{{sub $.all.NbWords 1}}], t[{{sub $i 1}}] = madd3(m, q{{$i}}, c[0], c[2], c[1]) - {{- else}} - c[2], t[{{sub $i 1}}] = madd2(m, q{{$i}}, c[2], c[0]) - {{- end}} - {{- end}} - {{- end }} -} -{{- end}} -{{ end }} - -` diff --git a/internal/field/internal/templates/element/ops_generic.go b/internal/field/internal/templates/element/ops_generic.go deleted file mode 100644 index 8e89d2b4a..000000000 --- a/internal/field/internal/templates/element/ops_generic.go +++ /dev/null @@ -1,55 +0,0 @@ -package element - -const OpsNoAsm = ` - -{{ $mulConsts := list 3 5 13 }} -{{- range $i := $mulConsts }} - -// MulBy{{$i}} x *= {{$i}} (mod q) -func MulBy{{$i}}(x *{{$.ElementName}}) { - {{- if eq 1 $.NbWords}} - var y {{$.ElementName}} - y.SetUint64({{$i}}) - x.Mul(x, &y) - {{- else}} - {{- if eq $i 3}} - _x := *x - x.Double(x).Add(x, &_x) - {{- else if eq $i 5}} - _x := *x - x.Double(x).Double(x).Add(x, &_x) - {{- else if eq $i 13}} - var y = {{$.ElementName}}{ - {{- range $i := $.Thirteen}} - {{$i}},{{end}} - } - x.Mul(x, &y) - {{- else }} - NOT IMPLEMENTED - {{- end}} - {{- end}} -} - -{{- end}} - -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *{{.ElementName}}) { - _butterflyGeneric(a, b) -} - -{{- if ne .NbWords 1}} -func mul(z, x, y *{{.ElementName}}) { - _mulGeneric(z, x, y) -} -{{- end}} - -func fromMont(z *{{.ElementName}} ) { - _fromMontGeneric(z) -} - -func reduce(z *{{.ElementName}}) { - _reduceGeneric(z) -} -` diff --git a/internal/generator/addchain/1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82 b/internal/generator/addchain/1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82 new file mode 100644 index 000000000..9b6ee5ece Binary files /dev/null and b/internal/generator/addchain/1fffffffffffffffffffffffffffffffd755db9cd5e9140777fa4bd19a06c82 differ diff --git a/internal/generator/addchain/3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c b/internal/generator/addchain/3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c new file mode 100644 index 000000000..7bfbed2ec Binary files /dev/null and b/internal/generator/addchain/3fffffffffffffffffffffffffffffffffffffffffffffffffffffffbfffff0c differ diff --git a/internal/generator/addchain/7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0 b/internal/generator/addchain/7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0 new file mode 100644 index 000000000..ed212529f Binary files /dev/null and b/internal/generator/addchain/7fffffffffffffffffffffffffffffff5d576e7357a4501ddfe92f46681b20a0 differ diff --git a/internal/generator/addchain/7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17 b/internal/generator/addchain/7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17 new file mode 100644 index 000000000..4e6da0c92 Binary files /dev/null and b/internal/generator/addchain/7fffffffffffffffffffffffffffffffffffffffffffffffffffffff7ffffe17 differ diff --git a/internal/generator/config/curve.go b/internal/generator/config/curve.go index 0d387a7cf..0c55a6a97 100644 --- a/internal/generator/config/curve.go +++ b/internal/generator/config/curve.go @@ -3,7 +3,7 @@ package config import ( "math/big" - "github.com/consensys/gnark-crypto/internal/field" + "github.com/consensys/gnark-crypto/field/generator/config" ) // Curve describes parameters of the curve useful for the template @@ -15,8 +15,8 @@ type Curve struct { FpModulus string FrModulus string - Fp *field.FieldConfig - Fr *field.FieldConfig + Fp *config.FieldConfig + Fr *config.FieldConfig FpUnusedBits int FpInfo, FrInfo Field @@ -68,7 +68,7 @@ var TwistedEdwardsCurves []TwistedEdwardsCurve func defaultCRange() []int { // default range for C values in the multiExp - return []int{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 21} + return []int{4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} } func addCurve(c *Curve) { @@ -90,7 +90,13 @@ func newFieldInfo(modulus string) Field { } F.Bits = bModulus.BitLen() - F.Bytes = len(bModulus.Bits()) * 8 + F.Bytes = (F.Bits + 7) / 8 F.Modulus = func() *big.Int { return new(big.Int).Set(&bModulus) } return F } + +type FieldDependency struct { + FieldPackagePath string + ElementType string + FieldPackageName string +} diff --git a/internal/generator/config/hash_to_curve.go b/internal/generator/config/hash_to_curve.go index 801f335d9..cb7c5b914 100644 --- a/internal/generator/config/hash_to_curve.go +++ b/internal/generator/config/hash_to_curve.go @@ -1,8 +1,9 @@ package config import ( - "github.com/consensys/gnark-crypto/internal/field" "math/big" + + field "github.com/consensys/gnark-crypto/field/generator/config" ) type FieldElementToCurvePoint string @@ -178,7 +179,7 @@ type IsogenyInfo struct { type RationalPolynomialInfo struct { Num []field.Element - Den []field.Element //Denominator is monic. The leading coefficient (1) is omitted. + Den []field.Element //denominator is monic. The leading coefficient (1) is omitted. } type HashSuiteInfo struct { diff --git a/internal/generator/config/secp256k1.go b/internal/generator/config/secp256k1.go new file mode 100644 index 000000000..29a4f669c --- /dev/null +++ b/internal/generator/config/secp256k1.go @@ -0,0 +1,28 @@ +package config + +var SECP256K1 = Curve{ + Name: "secp256k1", + CurvePackage: "secp256k1", + EnumID: "SECP256k1", + FrModulus: "115792089237316195423570985008687907852837564279074904382605163141518161494337", + FpModulus: "115792089237316195423570985008687907853269984665640564039457584007908834671663", + G1: Point{ + CoordType: "fp.Element", + CoordExtDegree: 1, + PointName: "g1", + GLV: true, + CofactorCleaning: false, + CRange: defaultCRange(), + }, + HashE1: &HashSuiteSvdw{ + z: []string{"1"}, + c1: []string{"8"}, + c2: []string{"57896044618658097711785492504343953926634992332820282019728792003954417335831"}, + c3: []string{"10388779673325959979325452626823788324994718367665745800388075445979975427086"}, + c4: []string{"77194726158210796949047323339125271902179989777093709359638389338605889781098"}, + }, +} + +func init() { + addCurve(&SECP256K1) +} diff --git a/internal/generator/crypto/hash/mimc/generate.go b/internal/generator/crypto/hash/mimc/generate.go index c26a08452..94caf7f56 100644 --- a/internal/generator/crypto/hash/mimc/generate.go +++ b/internal/generator/crypto/hash/mimc/generate.go @@ -1,6 +1,7 @@ package mimc import ( + "os" "path/filepath" "github.com/consensys/bavard" @@ -8,11 +9,19 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } + conf.Package = "mimc" entries := []bavard.Entry{ {File: filepath.Join(baseDir, "doc.go"), Templates: []string{"doc.go.tmpl"}}, {File: filepath.Join(baseDir, "mimc.go"), Templates: []string{"mimc.go.tmpl"}}, + {File: filepath.Join(baseDir, "decompose.go"), Templates: []string{"decompose.go.tmpl"}}, + {File: filepath.Join(baseDir, "decompose_test.go"), Templates: []string{"tests/decompose.go.tmpl"}}, } + os.Remove(filepath.Join(baseDir, "utils.go")) + os.Remove(filepath.Join(baseDir, "utils_test.go")) return bgen.Generate(conf, conf.Package, "./crypto/hash/mimc/template", entries...) } diff --git a/internal/generator/crypto/hash/mimc/template/decompose.go.tmpl b/internal/generator/crypto/hash/mimc/template/decompose.go.tmpl new file mode 100644 index 000000000..1658a9e1d --- /dev/null +++ b/internal/generator/crypto/hash/mimc/template/decompose.go.tmpl @@ -0,0 +1,28 @@ + +import ( "math/big" + + "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" +) + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte) []fr.Element { + + rawBigInt := big.NewInt(0).SetBytes(rawBytes) + modulo := fr.Modulus() + + // maximum number of chunks that a function + maxNbChunks := len(rawBytes) / fr.Bytes + + res := make([]fr.Element, 0, maxNbChunks) + var tmp fr.Element + t := new(big.Int) + for rawBigInt.Sign() != 0 { + rawBigInt.DivMod(rawBigInt, modulo, t) + tmp.SetBigInt(t) + res = append(res, tmp) + } + + return res +} diff --git a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl index b14de2c01..463a6767c 100644 --- a/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl +++ b/internal/generator/crypto/hash/mimc/template/mimc.go.tmpl @@ -1,5 +1,6 @@ import ( "hash" + "errors" "math/big" "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" @@ -77,44 +78,44 @@ func (d *digest) BlockSize() int { } // Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. +// +// Each []byte block of size BlockSize represents a big endian fr.Element. +// +// If len(p) is not a multiple of BlockSize and any of the []byte in p represent an integer +// larger than fr.Modulus, this function returns an error. +// +// To hash arbitrary data ([]byte not representing canonical field elements) use Decompose +// function in this package. func (d *digest) Write(p []byte) (n int, err error) { n = len(p) + if n % BlockSize != 0 { + return 0, errors.New("invalid input length: must represent a list of field elements, expects a []byte of len m*BlockSize") + } + + // ensure each block represents a field element in canonical reduced form + for i := 0; i < n; i += BlockSize { + if _, err = fr.BigEndian.Element((*[BlockSize]byte)(p[i:i+BlockSize])); err != nil { + return 0, err + } + } + d.data = append(d.data, p...) return } -// Hash hash using Miyaguchi–Preneel: +// Hash hash using Miyaguchi-Preneel: // https://en.wikipedia.org/wiki/One-way_compression_function // The XOR operation is replaced by field addition, data is in Montgomery form func (d *digest) checksum() fr.Element { - - var buffer [BlockSize]byte - var x fr.Element - - // if data size is not multiple of BlockSizes we padd: - // .. || 0xaf8 -> .. || 0x0000...0af8 - if len(d.data)%BlockSize != 0 { - q := len(d.data) / BlockSize - r := len(d.data) % BlockSize - sliceq := make([]byte, q*BlockSize) - copy(sliceq, d.data) - slicer := make([]byte, r) - copy(slicer, d.data[q*BlockSize:]) - sliceremainder := make([]byte, BlockSize-r) - d.data = append(sliceq, sliceremainder...) - d.data = append(d.data, slicer...) - } - - if len(d.data) == 0 { - d.data = make([]byte, 32) + // Write guarantees len(data) % BlockSize == 0 + + // TODO @ThomasPiellard shouldn't Sum() returns an error if there is no data? + if len(d.data) == 0 { + d.data = make([]byte, BlockSize) } - nbChunks := len(d.data) / BlockSize - - for i := 0; i < nbChunks; i++ { - copy(buffer[:], d.data[i*BlockSize:(i+1)*BlockSize]) - x.SetBytes(buffer[:]) + for i := 0; i < len(d.data); i+=BlockSize { + x, _ := fr.BigEndian.Element((*[BlockSize]byte)(d.data[i:i+BlockSize])) r := d.encrypt(x) d.h.Add(&r, &d.h).Add(&d.h, &x) } diff --git a/internal/generator/crypto/hash/mimc/template/tests/decompose.go.tmpl b/internal/generator/crypto/hash/mimc/template/tests/decompose.go.tmpl new file mode 100644 index 000000000..26dc6661f --- /dev/null +++ b/internal/generator/crypto/hash/mimc/template/tests/decompose.go.tmpl @@ -0,0 +1,35 @@ +import ( + "math/big" + "testing" + + "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" +) + +func TestDecompose(t *testing.T) { + + // create 10 random digits in basis r + nbDigits := 10 + a := make([]fr.Element, nbDigits) + for i := 0; i < nbDigits; i++ { + a[i].SetRandom() + } + + // create a big int whose digits in basis r are a + m := fr.Modulus() + var b, tmp big.Int + for i := nbDigits - 1; i >= 0; i-- { + b.Mul(&b, m) + a[i].ToBigIntRegular(&tmp) + b.Add(&b, &tmp) + } + + // query the decomposition and compare to a + bb := b.Bytes() + d := Decompose(bb) + for i := 0; i < nbDigits; i++ { + if !d[i].Equal(&a[i]) { + t.Fatal("error decomposition") + } + } + +} diff --git a/internal/generator/ecc/generate.go b/internal/generator/ecc/generate.go index 594e1343c..49f6e7b2b 100644 --- a/internal/generator/ecc/generate.go +++ b/internal/generator/ecc/generate.go @@ -3,6 +3,8 @@ package ecc import ( "fmt" "path/filepath" + "reflect" + "sort" "strings" "text/template" @@ -11,16 +13,120 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } + packageName := strings.ReplaceAll(conf.Name, "-", "") entries := []bavard.Entry{ {File: filepath.Join(baseDir, "multiexp.go"), Templates: []string{"multiexp.go.tmpl"}}, + {File: filepath.Join(baseDir, "multiexp_affine.go"), Templates: []string{"multiexp_affine.go.tmpl"}}, + {File: filepath.Join(baseDir, "multiexp_jacobian.go"), Templates: []string{"multiexp_jacobian.go.tmpl"}}, {File: filepath.Join(baseDir, "multiexp_test.go"), Templates: []string{"tests/multiexp.go.tmpl"}}, {File: filepath.Join(baseDir, "marshal.go"), Templates: []string{"marshal.go.tmpl"}}, {File: filepath.Join(baseDir, "marshal_test.go"), Templates: []string{"tests/marshal.go.tmpl"}}, } conf.Package = packageName - if err := bgen.Generate(conf, packageName, "./ecc/template", entries...); err != nil { + funcs := make(template.FuncMap) + funcs["last"] = func(x int, a interface{}) bool { + return x == reflect.ValueOf(a).Len()-1 + } + + // return the last window size for a scalar; + // this last window should accomodate a carry (from the NAF decomposition) + // it can be == c if we have 1 available bit + // it can be > c if we have 0 available bit + // it can be < c if we have 2+ available bits + lastC := func(c int) int { + nbChunks := (conf.Fr.NbBits + c - 1) / c + nbAvailableBits := (nbChunks * c) - conf.Fr.NbBits + lc := c + 1 - nbAvailableBits + if lc > 16 { + panic("we have a problem since we are using uint16 to store digits") + } + return lc + } + batchSize := func(c int) int { + // nbBuckets := (1 << (c - 1)) + // if c <= 12 { + // return nbBuckets/10 + 3*c + // } + // if c <= 14 { + // return nbBuckets/15 + // } + // return nbBuckets / 20 + // TODO @gbotrel / @yelhousni this need a better heuristic + // in theory, larger batch size == less inversions + // but if nbBuckets is small, then a large batch size will produce lots of collisions + // and queue ops. + // there is probably a cache-friendlyness factor at play here too. + switch c { + case 10: + return 80 + case 11: + return 150 + case 12: + return 200 + case 13: + return 350 + case 14: + return 400 + case 15: + return 500 + default: + return 640 + } + } + funcs["lastC"] = lastC + funcs["batchSize"] = batchSize + + funcs["nbBuckets"] = func(c int) int { + return 1 << (c - 1) + } + + funcs["contains"] = func(v int, s []int) bool { + for _, sv := range s { + if v == sv { + return true + } + } + return false + } + lastCG1 := make([]int, 0) + for { + for i := 0; i < len(conf.G1.CRange); i++ { + lc := lastC(conf.G1.CRange[i]) + if !contains(conf.G1.CRange, lc) && !contains(lastCG1, lc) { + lastCG1 = append(lastCG1, lc) + } + } + if len(lastCG1) == 0 { + break + } + conf.G1.CRange = append(conf.G1.CRange, lastCG1...) + sort.Ints(conf.G1.CRange) + lastCG1 = lastCG1[:0] + } + + lastCG2 := make([]int, 0) + for { + for i := 0; i < len(conf.G2.CRange); i++ { + lc := lastC(conf.G2.CRange[i]) + if !contains(conf.G2.CRange, lc) && !contains(lastCG2, lc) { + lastCG2 = append(lastCG2, lc) + } + } + if len(lastCG2) == 0 { + break + } + conf.G2.CRange = append(conf.G2.CRange, lastCG2...) + sort.Ints(conf.G2.CRange) + lastCG2 = lastCG2[:0] + } + + bavardOpts := []func(*bavard.Bavard) error{bavard.Funcs(funcs)} + if err := bgen.GenerateWithOptions(conf, packageName, "./ecc/template", bavardOpts, entries...); err != nil { return err } @@ -74,3 +180,12 @@ type pconf struct { config.Curve config.Point } + +func contains(slice []int, v int) bool { + for i := 0; i < len(slice); i++ { + if slice[i] == v { + return true + } + } + return false +} diff --git a/internal/generator/ecc/template/hash_to_curve.go.tmpl b/internal/generator/ecc/template/hash_to_curve.go.tmpl index d4e720bf5..ade6de5fd 100644 --- a/internal/generator/ecc/template/hash_to_curve.go.tmpl +++ b/internal/generator/ecc/template/hash_to_curve.go.tmpl @@ -14,9 +14,6 @@ import( {{- if not (eq $TowerDegree 1) }} "github.com/consensys/gnark-crypto/ecc/{{.Name}}/internal/fptower" {{- end}} - {{- if $IsG1}} - "github.com/consensys/gnark-crypto/ecc" - {{- end}} {{if eq $.MappingAlgorithm "SSWU"}} {{template "sswu" .}} @@ -25,37 +22,14 @@ import( {{template "svdw" .}} {{end}} -{{if $IsG1}} -// hashToFp hashes msg to count prime field elements. -// https://tools.ietf.org/html/draft-irtf-cfrg-hash-to-curve-06#section-5.2 -func hashToFp(msg, dst []byte, count int) ([]fp.Element, error) { - // 128 bits of security - // L = ceil((ceil(log2(p)) + k) / 8), where k is the security parameter = 128 - const Bytes = 1 + (fp.Bits - 1 ) / 8 - const L = 16 + Bytes - - lenInBytes := count * L - pseudoRandomBytes, err := ecc.ExpandMsgXmd(msg, dst, lenInBytes) - if err != nil { - return nil, err - } - - res := make([]fp.Element, count) - for i := 0; i < count; i++ { - res[i].SetBytes(pseudoRandomBytes[i*L : (i+1)*L]) - } - return res, nil -} -{{end}} - // {{$CurveName}}Sgn0 is an algebraic substitute for the notion of sign in ordered fields // Namely, every non-zero quadratic residue in a finite field of characteristic =/= 2 has exactly two square roots, one of each sign // https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#name-the-sgn0-function // The sign of an element is not obviously related to that of its Montgomery form func {{$CurveName}}Sgn0(z *{{$CoordType}}) uint64 { - nonMont := *z - nonMont.FromMont() + nonMont := z.Bits() + {{if eq $TowerDegree 1}} // m == 1 return nonMont[0]%2 {{else}} @@ -105,7 +79,7 @@ func MapTo{{$CurveTitle}}(u {{$CoordType}}) {{$AffineType}} { func EncodeTo{{$CurveTitle}}(msg, dst []byte) ({{$AffineType}}, error) { var res {{$AffineType}} - u, err := hashToFp(msg, dst, {{$TowerDegree}}) + u, err := fp.Hash(msg, dst, {{$TowerDegree}}) if err != nil { return res, err } @@ -133,7 +107,7 @@ func EncodeTo{{$CurveTitle}}(msg, dst []byte) ({{$AffineType}}, error) { // dst stands for "domain separation tag", a string unique to the construction using the hash function //https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-16.html#roadmap func HashTo{{$CurveTitle}}(msg, dst []byte) ({{$AffineType}}, error) { - u, err := hashToFp(msg, dst, 2 * {{$TowerDegree}}) + u, err := fp.Hash(msg, dst, 2 * {{$TowerDegree}}) if err != nil { return {{$AffineType}}{}, err } diff --git a/internal/generator/ecc/template/marshal.go.tmpl b/internal/generator/ecc/template/marshal.go.tmpl index f4a3558ce..f3aa89e0d 100644 --- a/internal/generator/ecc/template/marshal.go.tmpl +++ b/internal/generator/ecc/template/marshal.go.tmpl @@ -106,7 +106,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fr.Bytes]) + err = t.SetBytesCanonical(buf[:fr.Bytes]) return case *fp.Element: read, err = io.ReadFull(dec.r, buf[:fp.Bytes]) @@ -114,7 +114,7 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - t.SetBytes(buf[:fp.Bytes]) + err = t.SetBytesCanonical(buf[:fp.Bytes]) return case *[]fr.Element: var sliceLen uint32 @@ -132,7 +132,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fr.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fr.Bytes]); err != nil { + return + } } return case *[]fp.Element: @@ -151,7 +153,9 @@ func (dec *Decoder) Decode(v interface{}) (err error) { if err != nil { return } - (*t)[i].SetBytes(buf[:fp.Bytes]) + if err = (*t)[i].SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return + } } return case *G1Affine: @@ -227,7 +231,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -282,7 +290,11 @@ func (dec *Decoder) Decode(v interface{}) (err error) { return } } else { - compressed[i] = !((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])) + var r bool + if r, err = ((*t)[i].unsafeSetCompressedBytes(buf[:nbBytes])); err != nil { + return + } + compressed[i] = !r } } var nbErrs uint64 @@ -580,8 +592,6 @@ func (p *{{ $.TAffine }}) Bytes() (res [SizeOf{{ $.TAffine }}Compressed]byte) { return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element msbMask := mCompressedSmallest // compressed, we need to know if Y is lexicographically bigger than -Y @@ -629,8 +639,6 @@ func (p *{{ $.TAffine }}) RawBytes() (res [SizeOf{{ $.TAffine }}Uncompressed]byt return } - // tmp is used to convert from montgomery representation to regular - var tmp fp.Element // not compressed // we store the Y coordinate @@ -731,25 +739,53 @@ func (p *{{ $.TAffine }}) setBytes(buf []byte, subGroupCheck bool) (int, error) // read X and Y coordinates {{- if eq $.CoordType "fptower.E2"}} // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(buf[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes:fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes:fp.Bytes*2]); err != nil { + return 0, err + } // p.Y.A1 | p.Y.A0 - p.Y.A1.SetBytes(buf[fp.Bytes*2:fp.Bytes*3]) - p.Y.A0.SetBytes(buf[fp.Bytes*3:fp.Bytes*4]) + if err := p.Y.A1.SetBytesCanonical(buf[fp.Bytes*2:fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.Y.A0.SetBytesCanonical(buf[fp.Bytes*3:fp.Bytes*4]); err != nil { + return 0, err + } {{- else if eq $.CoordType "fptower.E4"}} // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(buf[fp.Bytes*0:fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1:fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2:fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3:fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(buf[fp.Bytes*0:fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1:fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2:fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3:fp.Bytes*4]); err != nil { + return 0, err + } // p.Y.B1.A1 | p.Y.B1.A0 | p.Y.B0.A1 | p.Y.B0.A0 - p.Y.B1.A1.SetBytes(buf[fp.Bytes*4:fp.Bytes*5]) - p.Y.B1.A0.SetBytes(buf[fp.Bytes*5:fp.Bytes*6]) - p.Y.B0.A1.SetBytes(buf[fp.Bytes*6:fp.Bytes*7]) - p.Y.B0.A0.SetBytes(buf[fp.Bytes*7:fp.Bytes*8]) + if err := p.Y.B1.A1.SetBytesCanonical(buf[fp.Bytes*4:fp.Bytes*5]); err != nil { + return 0, err + } + if err := p.Y.B1.A0.SetBytesCanonical(buf[fp.Bytes*5:fp.Bytes*6]); err != nil { + return 0, err + } + if err := p.Y.B0.A1.SetBytesCanonical(buf[fp.Bytes*6:fp.Bytes*7]); err != nil { + return 0, err + } + if err := p.Y.B0.A0.SetBytesCanonical(buf[fp.Bytes*7:fp.Bytes*8]); err != nil { + return 0, err + } {{- else}} - p.X.SetBytes(buf[:fp.Bytes]) - p.Y.SetBytes(buf[fp.Bytes:fp.Bytes*2]) + if err := p.X.SetBytesCanonical(buf[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.Y.SetBytesCanonical(buf[fp.Bytes:fp.Bytes*2]); err != nil { + return 0, err + } {{- end}} // subgroup check @@ -772,16 +808,30 @@ func (p *{{ $.TAffine }}) setBytes(buf []byte, subGroupCheck bool) (int, error) // read X coordinate {{- if eq $.CoordType "fptower.E2"}} // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes:fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes:fp.Bytes*2]); err != nil { + return 0, err + } {{- else if eq $.CoordType "fptower.E4"}} // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0:fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1:fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2:fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3:fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0:fp.Bytes*1]); err != nil { + return 0, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1:fp.Bytes*2]); err != nil { + return 0, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2:fp.Bytes*3]); err != nil { + return 0, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3:fp.Bytes*4]); err != nil { + return 0, err + } {{- else}} - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return 0, err + } {{- end}} @@ -883,7 +933,7 @@ func (p *{{ $.TAffine }}) unsafeComputeY(subGroupCheck bool) error { // assumes buf[:8] mask is set to compressed // returns true if point is infinity and need no further processing // it sets X coordinate and uses Y for scratch space to store decompression metadata -func (p *{{ $.TAffine }}) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) { +func (p *{{ $.TAffine }}) unsafeSetCompressedBytes(buf []byte) (isInfinity bool, err error) { // read the most significant byte mData := buf[0] & mMask @@ -892,7 +942,7 @@ func (p *{{ $.TAffine }}) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) p.X.SetZero() p.Y.SetZero() isInfinity = true - return + return isInfinity, nil } // we need to copy the input buffer (to keep this method thread safe) @@ -903,28 +953,42 @@ func (p *{{ $.TAffine }}) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) // read X coordinate {{- if eq $.CoordType "fptower.E2"}} // p.X.A1 | p.X.A0 - p.X.A1.SetBytes(bufX[:fp.Bytes]) - p.X.A0.SetBytes(buf[fp.Bytes:fp.Bytes*2]) + if err := p.X.A1.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } + if err := p.X.A0.SetBytesCanonical(buf[fp.Bytes:fp.Bytes*2]); err != nil { + return false, err + } // store mData in p.Y.A0[0] p.Y.A0[0] = uint64(mData) {{- else if eq $.CoordType "fptower.E4"}} // p.X.B1.A1 | p.X.B1.A0 | p.X.B0.A1 | p.X.B0.A0 - p.X.B1.A1.SetBytes(bufX[fp.Bytes*0:fp.Bytes*1]) - p.X.B1.A0.SetBytes(buf[fp.Bytes*1:fp.Bytes*2]) - p.X.B0.A1.SetBytes(buf[fp.Bytes*2:fp.Bytes*3]) - p.X.B0.A0.SetBytes(buf[fp.Bytes*3:fp.Bytes*4]) + if err := p.X.B1.A1.SetBytesCanonical(bufX[fp.Bytes*0:fp.Bytes*1]); err != nil { + return false, err + } + if err := p.X.B1.A0.SetBytesCanonical(buf[fp.Bytes*1:fp.Bytes*2]); err != nil { + return false, err + } + if err := p.X.B0.A1.SetBytesCanonical(buf[fp.Bytes*2:fp.Bytes*3]); err != nil { + return false, err + } + if err := p.X.B0.A0.SetBytesCanonical(buf[fp.Bytes*3:fp.Bytes*4]); err != nil { + return false, err + } // store mData in p.Y.B0.A0[0] p.Y.B0.A0[0] = uint64(mData) {{- else}} - p.X.SetBytes(bufX[:fp.Bytes]) + if err := p.X.SetBytesCanonical(bufX[:fp.Bytes]); err != nil { + return false, err + } // store mData in p.Y[0] p.Y[0] = uint64(mData) {{- end}} // recomputing Y will be done asynchronously - return + return isInfinity, nil } @@ -934,15 +998,6 @@ func (p *{{ $.TAffine }}) unsafeSetCompressedBytes(buf []byte) (isInfinity bool) -{{define "putFp"}} - tmp = {{$.From}} - tmp.FromMont() - {{- range $i := reverse .all.Fp.NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $j := add $j $.OffSet}} - {{- $k := sub $.all.Fp.NbWords 1}} - {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64(res[{{$j}}:{{$jj}}], tmp[{{$k}}]) - {{- end}} -{{end}} +{{- define "putFp"}} + fp.BigEndian.PutElement((*[fp.Bytes]byte)( res[{{$.OffSet}}:{{$.OffSet}} + fp.Bytes]), {{$.From}}) +{{- end}} diff --git a/internal/generator/ecc/template/multiexp.go.tmpl b/internal/generator/ecc/template/multiexp.go.tmpl index e476f889b..b77285cfc 100644 --- a/internal/generator/ecc/template/multiexp.go.tmpl +++ b/internal/generator/ecc/template/multiexp.go.tmpl @@ -16,6 +16,10 @@ import ( "runtime" ) +{{ template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} +{{ template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} + + // selector stores the index, mask and shifts needed to select bits from a scalar // it is used during the multiExp algorithm or the batch scalar multiplication type selector struct { @@ -28,26 +32,46 @@ type selector struct { shiftHigh uint64 // same than shift, for index+1 } +// return number of chunks for a given window size c +// the last chunk may be bigger to accomodate a potential carry from the NAF decomposition +func computeNbChunks(c uint64) uint64 { + return (fr.Bits+c-1) / c +} + +// return the last window size for a scalar; +// this last window should accomodate a carry (from the NAF decomposition) +// it can be == c if we have 1 available bit +// it can be > c if we have 0 available bit +// it can be < c if we have 2+ available bits +func lastC(c uint64) uint64 { + nbAvailableBits := (computeNbChunks(c)*c) - fr.Bits + return c+1-nbAvailableBits +} + +type chunkStat struct { + // relative weight of work compared to other chunks. 100.0 -> nominal weight. + weight float32 + + // percentage of bucket filled in the window; + ppBucketFilled float32 + nbBucketFilled int +} + + + // partitionScalars compute, for each scalars over c-bit wide windows, nbChunk digits // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract // 2^{c} to the current digit, making it negative. // negative digits can be processed in a later step as adding -G into the bucket instead of G // (computing -G is cheap, and this saves us half of the buckets in the MultiExp or BatchScalarMultiplication) -// scalarsMont indicates wheter the provided scalars are in montgomery form -// returns smallValues, which represent the number of scalars which meets the following condition -// 0 < scalar < 2^c (in other words, scalars where only the c-least significant bits are non zero) -func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks int) ([]fr.Element, int) { - toReturn := make([]fr.Element, len(scalars)) - +func partitionScalars(scalars []fr.Element, c uint64, nbTasks int) ([]uint16, []chunkStat) { // number of c-bit radixes in a scalar - nbChunks := fr.Limbs * 64 / c - if (fr.Limbs * 64)%c != 0 { - nbChunks++ - } + nbChunks := computeNbChunks(c) + + digits := make([]uint16, len(scalars)*int(nbChunks)) mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c -1)) // msb of the c-bit window - max := int(1 << (c -1)) // max value we want for our digits + max := int(1 << (c -1)) - 1 // max value (inclusive) we want for our digits cDivides64 := (64 %c ) == 0 // if c doesn't divide 64, we may need to select over multiple words @@ -68,38 +92,19 @@ func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks selectors[chunk] = d } - // for each chunk, we could track the number of non-zeros points we will need to process - // this way, if a chunk has more work to do than others, we can spawn off more go routines - // (at the cost of more buckets allocated) - // a simplified approach is to track the small values where only the first word is set - // if this number represent a significant number of points, then we will split first chunk - // processing in the msm in 2, to ensure all go routines finish at ~same time - // /!\ nbTasks is enough as parallel.Execute is not going to spawn more than nbTasks go routine - // if it does, though, this will deadlocK. - chSmallValues := make(chan int, nbTasks) parallel.Execute(len(scalars), func(start, end int) { - smallValues := 0 for i:=start; i < end; i++ { - var carry int - - scalar := scalars[i] - if scalarsMont { - scalar.FromMont() - } - if scalar.FitsOnOneWord() { + if scalars[i].IsZero() { // everything is 0, no need to process this scalar - if scalar[0] == 0 { - continue - } - // low c-bits are 1 in mask - if scalar[0]&mask == scalar[0] { - smallValues++ - } + continue } + scalar := scalars[i].Bits() + + var carry int // for each chunk in the scalar, compute the current digit, and an eventual carry - for chunk := uint64(0); chunk < nbChunks; chunk++ { + for chunk := uint64(0); chunk < nbChunks - 1; chunk++ { s := selectors[chunk] // init with carry if any @@ -114,70 +119,117 @@ func partitionScalars(scalars []fr.Element, c uint64, scalarsMont bool, nbTasks digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh } - // if digit is zero, no impact on result - if digit == 0 { - continue - } - // if the digit is larger than 2^{c-1}, then, we borrow 2^c from the next window and substract // 2^{c} to the current digit, making it negative. - if digit >= max { + if digit > max { digit -= (1 << c) carry = 1 } - var bits uint64 - if digit >= 0 { - bits = uint64(digit) - } else { - bits = uint64(-digit-1) | msbWindow + // if digit is zero, no impact on result + if digit == 0 { + continue } - toReturn[i][s.index] |= (bits << s.shift) - if s.multiWordSelect { - toReturn[i][s.index+1] |= (bits >> s.shiftHigh) + var bits uint16 + if digit > 0 { + bits = uint16(digit) << 1 + } else { + bits = (uint16(-digit-1) << 1) + 1 } + digits[int(chunk)*len(scalars)+i] = bits + } + // for the last chunk, we don't want to borrow from a next window + // (but may have a larger max value) + chunk := nbChunks - 1 + s := selectors[chunk] + // init with carry if any + digit := carry + // digit = value of the c-bit window + digit += int((scalar[s.index] & s.mask) >> s.shift) + if s.multiWordSelect { + // we are selecting bits over 2 words + digit += int(scalar[s.index+1] & s.maskHigh) << s.shiftHigh } + digits[int(chunk)*len(scalars)+i] = uint16(digit) << 1 } - chSmallValues <- smallValues + }, nbTasks) + + // aggregate chunk stats + chunkStats := make([]chunkStat, nbChunks) + if c <= 9 { + // no need to compute stats for small window sizes + return digits, chunkStats + } + parallel.Execute(len(chunkStats), func(start, end int) { + // for each chunk compute the statistics + for chunkID := start; chunkID < end; chunkID++ { + // indicates if a bucket is hit. + var b bitSetC16 + + // digits for the chunk + chunkDigits := digits[chunkID*len(scalars):(chunkID+1)*len(scalars)] + + totalOps := 0 + nz := 0 // non zero buckets count + for _, digit := range chunkDigits { + if digit == 0 { + continue + } + totalOps++ + bucketID := digit >> 1 + if digit &1 == 0 { + bucketID-=1 + } + if !b[bucketID] { + nz++ + b[bucketID] = true + } + } + chunkStats[chunkID].weight = float32(totalOps) // count number of ops for now, we will compute the weight after + chunkStats[chunkID].ppBucketFilled = (float32(nz) * 100.0) / float32(int(1 << (c-1))) + chunkStats[chunkID].nbBucketFilled = nz + } }, nbTasks) - - - // aggregate small values - close(chSmallValues) - smallValues := 0 - for o := range chSmallValues { - smallValues+=o + + totalOps := float32(0.0) + for _, stat := range chunkStats { + totalOps+=stat.weight } - return toReturn, smallValues -} + target := totalOps / float32(nbChunks) + if target != 0.0 { + // if target == 0, it means all the scalars are 0 everywhere, there is no work to be done. + for i := 0; i < len(chunkStats); i++ { + chunkStats[i].weight = (chunkStats[i].weight * 100.0) / target + } + } -{{ template "multiexp" dict "PointName" .G1.PointName "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} -{{ template "multiexp" dict "PointName" .G2.PointName "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} + return digits, chunkStats +} {{define "multiexp" }} // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// +// // This call return an error if len(scalars) != len(points) or if provided config is invalid. func (p *{{ $.TAffine }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TAffine }}, error) { var _p {{$.TJacobian}} if _, err := _p.MultiExp(points, scalars, config); err != nil { - return nil, err + return nil, err } p.FromJacobian(&_p) return p, nil } // MultiExp implements section 4 of https://eprint.iacr.org/2012/549.pdf -// +// // This call return an error if len(scalars) != len(points) or if provided config is invalid. func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) (*{{ $.TJacobian }}, error) { // note: @@ -224,7 +276,7 @@ func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Elem bestC := func(nbPoints int) uint64 { // implemented msmC methods (the c we use must be in this slice) implementedCs := []uint64{ - {{- range $c := $.CRange}} {{- if and (eq $.PointName "g1") (gt $c 21)}}{{- else}} {{$c}},{{- end}}{{- end}} + {{- range $c := $.CRange}}{{- if ge $c 4}}{{$c}},{{- end}}{{- end}} } var C uint64 // approximate cost (in group operations) @@ -233,79 +285,123 @@ func (p *{{ $.TJacobian }}) MultiExp(points []{{ $.TAffine }}, scalars []fr.Elem // for example, on a MBP 2016, for G2 MultiExp > 8M points, hand picking c gives better results min := math.MaxFloat64 for _, c := range implementedCs { - cc := fr.Limbs * 64 * (nbPoints + (1 << (c))) + cc := (fr.Bits+1) * (nbPoints + (1 << c)) cost := float64(cc) / float64(c) if cost < min { min = cost C = c } } - // empirical, needs to be tuned. - // if C > 16 && nbPoints < 1 << 23 { - // C = 16 - // } return C } - var C uint64 - nbSplits := 1 - nbChunks := 0 - for nbChunks < config.NbTasks { - C = bestC(nbPoints) - nbChunks = int(fr.Limbs * 64 / C) // number of c-bit radixes in a scalar - if (fr.Limbs * 64) % C != 0 { - nbChunks ++ - } - nbChunks *= nbSplits - if nbChunks < config.NbTasks { - nbSplits <<= 1 - nbPoints >>= 1 + C := bestC(nbPoints) + nbChunks := int(computeNbChunks(C)) + + // if we don't utilise all the tasks (CPU in the default case) that we could, let's see if it's worth it to split + if config.NbTasks > 1 && nbChunks < config.NbTasks { + // before spliting, let's see if we endup with more tasks than thread; + cSplit := bestC(nbPoints/2) + nbChunksPostSplit := int(computeNbChunks(cSplit)) + nbTasksPostSplit := nbChunksPostSplit*2 + if (nbTasksPostSplit <= config.NbTasks /2 ) || ( nbTasksPostSplit - config.NbTasks/2 ) <= ( config.NbTasks - nbChunks) { + // if postSplit we still have less tasks than available CPU + // or if we have more tasks BUT the difference of CPU usage is in our favor, we split. + config.NbTasks /= 2 + var _p {{ $.TJacobian }} + chDone := make(chan struct{}, 1) + go func() { + _p.MultiExp(points[:nbPoints/2], scalars[:nbPoints/2], config) + close(chDone) + }() + p.MultiExp(points[nbPoints/2:], scalars[nbPoints/2:], config) + <-chDone + p.AddAssign(&_p) + return p, nil } } + _innerMsm{{ $.UPointName }}(p, C, points, scalars, config) + + return p, nil +} + +func _innerMsm{{ $.UPointName }}(p *{{ $.TJacobian }}, c uint64, points []{{ $.TAffine }}, scalars []fr.Element, config ecc.MultiExpConfig) *{{ $.TJacobian }} { // partition the scalars - // note: we do that before the actual chunk processing, as for each c-bit window (starting from LSW) - // if it's larger than 2^{c-1}, we have a carry we need to propagate up to the higher window - var smallValues int - scalars, smallValues = partitionScalars(scalars, C, config.ScalarsMont, config.NbTasks) - - // if we have more than 10% of small values, we split the processing of the first chunk in 2 - // we may want to do that in msmInner{{ $.TJacobian }} , but that would incur a cost of looping through all scalars one more time - splitFirstChunk := (float64(smallValues) / float64(len(scalars))) >= 0.1 - - // we have nbSplits intermediate results that we must sum together. - _p := make([]{{ $.TJacobian }}, nbSplits - 1) - chDone := make(chan int, nbSplits - 1) - for i:=0; i < nbSplits-1; i++ { - start := i * nbPoints - end := start + nbPoints - go func(start, end, i int) { - msmInner{{ $.TJacobian }}(&_p[i], int(C), points[start:end], scalars[start:end], splitFirstChunk) - chDone <- i - }(start, end, i) + digits, chunkStats := partitionScalars(scalars, c, config.NbTasks) + + nbChunks := computeNbChunks(c) + + // for each chunk, spawn one go routine that'll loop through all the scalars in the + // corresponding bit-window + // note that buckets is an array allocated on the stack and this is critical for performance + + // each go routine sends its result in chChunks[i] channel + chChunks := make([]chan {{ $.TJacobianExtended }}, nbChunks) + for i:=0; i < len(chChunks);i++ { + chChunks[i] = make(chan {{ $.TJacobianExtended }}, 1) } - - msmInner{{ $.TJacobian }}(p, int(C), points[(nbSplits - 1) * nbPoints:], scalars[(nbSplits - 1) * nbPoints:], splitFirstChunk) - for i:=0; i < nbSplits-1; i++ { - done := <-chDone - p.AddAssign(&_p[done]) + + // the last chunk may be processed with a different method than the rest, as it could be smaller. + n := len(points) + for j := int(nbChunks - 1); j >= 0; j-- { + processChunk := getChunkProcessor{{ $.UPointName }}(c, chunkStats[j]) + if j == int(nbChunks - 1) { + processChunk = getChunkProcessor{{ $.UPointName }}(lastC(c), chunkStats[j]) + } + if chunkStats[j].weight >= 115 { + // we split this in more go routines since this chunk has more work to do than the others. + // else what would happen is this go routine would finish much later than the others. + chSplit := make(chan {{ $.TJacobianExtended }}, 2) + split := n / 2 + go processChunk(uint64(j),chSplit, c, points[:split], digits[j*n:(j*n)+split]) + go processChunk(uint64(j),chSplit, c, points[split:], digits[(j*n)+split:(j+1)*n]) + go func(chunkID int) { + s1 := <-chSplit + s2 := <-chSplit + close(chSplit) + s1.add(&s2) + chChunks[chunkID] <- s1 + }(j) + continue + } + go processChunk(uint64(j), chChunks[j], c, points, digits[j*n:(j+1)*n]) } - close(chDone) - return p, nil + + return msmReduceChunk{{ $.TAffine }}(p, int(c), chChunks[:]) } -func msmInner{{ $.TJacobian }}(p *{{ $.TJacobian }}, c int, points []{{ $.TAffine }}, scalars []fr.Element, splitFirstChunk bool) { +// getChunkProcessor{{ $.UPointName }} decides, depending on c window size and statistics for the chunk +// to return the best algorithm to process the chunk. +func getChunkProcessor{{ $.UPointName }}(c uint64, stat chunkStat) func(chunkID uint64, chRes chan<- {{ $.TJacobianExtended }}, c uint64, points []{{ $.TAffine }}, digits []uint16) { switch c { - {{range $c := $.CRange}} - case {{$c}}: - p.msmC{{$c}}(points, scalars, splitFirstChunk) - {{end}} - default: - panic("not implemented") + {{- range $c := $.LastCRange}} + case {{$c}}: + return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] + {{- end }} + {{range $c := $.CRange}} + case {{$c}}: + {{- if le $c 9}} + return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] + {{- else}} + const batchSize = {{batchSize $c}} + // here we could check some chunk statistic (deviation, ...) to determine if calling + // the batch affine version is worth it. + if stat.nbBucketFilled < batchSize { + // clear indicator that batch affine method is not appropriate here. + return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C{{$c}}] + } + return processChunk{{ $.UPointName }}BatchAffine[bucket{{ $.TJacobianExtended }}C{{$c}}, bucket{{ $.TAffine }}C{{$c}}, bitSetC{{$c}}, p{{$.TAffine}}C{{$c}}, pp{{$.TAffine}}C{{$c}}, q{{$.TAffine}}C{{$c}}, c{{$.TAffine}}C{{$c}}] + {{- end}} + {{- end}} + default: + // panic("will not happen c != previous values is not generated by templates") + return processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C16] } } + // msmReduceChunk{{ $.TAffine }} reduces the weighted sum of the buckets into the result of the multiExp func msmReduceChunk{{ $.TAffine }}(p *{{ $.TJacobian }}, c int, chChunks []chan {{ $.TJacobianExtended }}) *{{ $.TJacobian }} { var _p {{ $.TJacobianExtended }} @@ -323,138 +419,7 @@ func msmReduceChunk{{ $.TAffine }}(p *{{ $.TJacobian }}, c int, chChunks []chan } -func msmProcessChunk{{ $.TAffine }}(chunk uint64, - chRes chan<- {{ $.TJacobianExtended }}, - buckets []{{ $.TJacobianExtended }}, - c uint64, - points []{{ $.TAffine }}, - scalars []fr.Element) { - - - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c -1)) - - for i := 0 ; i < len(buckets); i++ { - buckets[i].setInfinity() - } - - jc := uint64(chunk * c) - s := selector{} - s.index = jc / 64 - s.shift = jc - (s.index * 64) - s.mask = mask << s.shift - s.multiWordSelect = (64 %c)!=0 && s.shift > (64-c) && s.index < (fr.Limbs - 1 ) - if s.multiWordSelect { - nbBitsHigh := s.shift - uint64(64-c) - s.maskHigh = (1 << nbBitsHigh) - 1 - s.shiftHigh = (c - nbBitsHigh) - } - - - // for each scalars, get the digit corresponding to the chunk we're processing. - for i := 0; i < len(scalars); i++ { - bits := (scalars[i][s.index] & s.mask) >> s.shift - if s.multiWordSelect { - bits += (scalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { - continue - } - - // if msbWindow bit is set, we need to substract - if bits & msbWindow == 0 { - // add - buckets[bits-1].addMixed(&points[i]) - } else { - // sub - buckets[bits & ^msbWindow].subMixed(&points[i]) - } - } - - - // reduce buckets into total - // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] - - var runningSum, total {{ $.TJacobianExtended }} - runningSum.setInfinity() - total.setInfinity() - for k := len(buckets) - 1; k >= 0; k-- { - if !buckets[k].ZZ.IsZero() { - runningSum.add(&buckets[k]) - } - total.add(&runningSum) - } - - chRes <- total - {{/* close(chRes) */}} -} - - -{{range $c := $.CRange}} - -{{- $frBits := mul $.FrNbWords 64}} -{{- $cDividesBits := divides $c $frBits}} -{{- $nbChunks := div $frBits $c}} - -func (p *{{ $.TJacobian }}) msmC{{$c}}(points []{{ $.TAffine }}, scalars []fr.Element, splitFirstChunk bool) *{{ $.TJacobian }} { - const ( - c = {{$c}} // scalars partitioned into c-bit radixes - nbChunks = (fr.Limbs * 64 / c) // number of c-bit radixes in a scalar - ) - - // for each chunk, spawn one go routine that'll loop through all the scalars in the - // corresponding bit-window - // note that buckets is an array allocated on the stack (for most sizes of c) and this is - // critical for performance - - // each go routine sends its result in chChunks[i] channel - var chChunks [nbChunks{{if not $cDividesBits }} + 1 {{end}} ]chan {{ $.TJacobianExtended }} - for i:=0; i < len(chChunks);i++ { - chChunks[i] = make(chan {{ $.TJacobianExtended }}, 1) - } - - - {{ if not $cDividesBits }} - - // c doesn't divide {{$frBits}}, last window is smaller we can allocate less buckets - const lastC = (fr.Limbs * 64) - (c * (fr.Limbs * 64 / c)) - go func(j uint64, points []{{ $.TAffine }}, scalars []fr.Element) { - var buckets [1<<(lastC-1)]{{ $.TJacobianExtended }} - msmProcessChunk{{ $.TAffine }}(j, chChunks[j], buckets[:], c, points, scalars) - }(uint64(nbChunks), points, scalars) - - {{- end}} - - processChunk := func(j int, points []{{ $.TAffine }}, scalars []fr.Element, chChunk chan {{ $.TJacobianExtended }}) { - var buckets [1<<(c-1)]{{ $.TJacobianExtended }} - msmProcessChunk{{ $.TAffine }}(uint64(j), chChunk, buckets[:], c, points, scalars) - } - - for j := int(nbChunks - 1); j >0; j-- { - go processChunk(j, points, scalars, chChunks[j]) - } - if !splitFirstChunk { - go processChunk(0, points, scalars, chChunks[0]) - } else { - chSplit := make(chan {{ $.TJacobianExtended }}, 2) - split := len(points) / 2 - go processChunk(0, points[:split], scalars[:split], chSplit) - go processChunk(0, points[split:], scalars[split:], chSplit) - go func() { - s1 := <-chSplit - s2 := <-chSplit - close(chSplit) - s1.add(&s2) - chChunks[0] <- s1 - }() - } - - - return msmReduceChunk{{ $.TAffine }}(p, c, chChunks[:]) -} -{{end}} {{end }} diff --git a/internal/generator/ecc/template/multiexp_affine.go.tmpl b/internal/generator/ecc/template/multiexp_affine.go.tmpl new file mode 100644 index 000000000..979d05c00 --- /dev/null +++ b/internal/generator/ecc/template/multiexp_affine.go.tmpl @@ -0,0 +1,311 @@ +{{ $G1TAffine := print (toUpper .G1.PointName) "Affine" }} +{{ $G1TJacobian := print (toUpper .G1.PointName) "Jac" }} +{{ $G1TJacobianExtended := print (toLower .G1.PointName) "JacExtended" }} + +{{ $G2TAffine := print (toUpper .G2.PointName) "Affine" }} +{{ $G2TJacobian := print (toUpper .G2.PointName) "Jac" }} +{{ $G2TJacobianExtended := print (toLower .G2.PointName) "JacExtended" }} + + +import ( + "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fp" + {{- if ne .G1.CoordType .G2.CoordType}} + "github.com/consensys/gnark-crypto/ecc/{{.Name}}/internal/fptower" + {{- end}} +) + +{{ template "multiexp" dict "CoordType" .G1.CoordType "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} +{{ template "multiexp" dict "CoordType" .G2.CoordType "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} + + +{{define "multiexp" }} + +type batchOp{{ $.TAffine }} struct { + bucketID uint16 + point {{ $.TAffine }} +} + +// processChunk{{ $.UPointName }}BatchAffine process a chunk of the scalars during the msm +// using affine coordinates for the buckets. To amortize the cost of the inverse in the affine addition +// we use a batch affine addition. +// +// this is derived from a PR by 0x0ece : https://github.com/ConsenSys/gnark-crypto/pull/249 +// See Section 5.3: ia.cr/2022/1396 +func processChunk{{ $.UPointName }}BatchAffine[BJE ib{{ $.TJacobianExtended }},B ib{{ $.TAffine }}, BS bitSet, TP p{{ $.TAffine }}, TPP pp{{ $.TAffine }}, TQ qOps{{ $.TAffine }}, TC c{{ $.TAffine}}]( + chunk uint64, + chRes chan<- {{ $.TJacobianExtended }}, + c uint64, + points []{{ $.TAffine }}, + digits []uint16) { + + // the batch affine addition needs independent points; in other words, for a window of batchSize + // we want to hit independent bucketIDs when processing the digit. if there is a conflict (we're trying + // to add 2 different points to the same bucket), then we push the conflicted point to a queue. + // each time the batch is full, we execute it, and tentatively put the points (if not conflict) + // from the top of the queue into the next batch. + // if the queue is full, we "flush it"; we sequentially add the points to the buckets in + // {{ $.TJacobianExtended }} coordinates. + // The reasoning behind this is the following; batchSize is chosen such as, for a uniformly random + // input, the number of conflicts is going to be low, and the element added to the queue should be immediatly + // processed in the next batch. If it's not the case, then our inputs are not random; and we fallback to + // non-batch-affine version. + + // note that we have 2 sets of buckets + // 1 in {{ $.TAffine }} used with the batch affine additions + // 1 in {{ $.TJacobianExtended }} used in case the queue of conflicting points + var buckets B + var bucketsJE BJE + for i := 0; i < len(buckets); i++ { + buckets[i].setInfinity() + bucketsJE[i].setInfinity() + } + + // setup for the batch affine; + var ( + bucketIds BS // bitSet to signify presence of a bucket in current batch + cptAdd int // count the number of bucket + point added to current batch + R TPP // bucket references + P TP // points to be added to R (buckets); it is beneficial to store them on the stack (ie copy) + queue TQ // queue of points that conflict the current batch + qID int // current position in queue + ) + + batchSize := len(P) + + isFull := func() bool { return cptAdd == batchSize} + + executeAndReset := func () { + batchAdd{{ $.TAffine }}[TP, TPP, TC](&R, &P, cptAdd) + var tmp BS + bucketIds = tmp + cptAdd = 0 + } + + addFromQueue := func(op batchOp{{ $.TAffine }}) { + // @precondition: must ensures bucket is not "used" in current batch + // note that there is a bit of duplicate logic between add and addFromQueue + // the reason is that as of Go 1.19.3, if we pass a pointer to the queue item (see add signature) + // the compiler will put the queue on the heap. + BK := &buckets[op.bucketID] + + // handle special cases with inf or -P / P + if BK.IsInfinity() { + BK.Set(&op.point) + return + } + if BK.X.Equal(&op.point.X) { + if BK.Y.Equal(&op.point.Y) { + // P + P: doubling, which should be quite rare -- + // we use the other set of buckets + bucketsJE[op.bucketID].addMixed(&op.point) + return + } + BK.setInfinity() + return + } + + bucketIds[op.bucketID] = true + R[cptAdd] = BK + P[cptAdd] = op.point + cptAdd++ + } + + add := func(bucketID uint16, PP *{{$.TAffine}}, isAdd bool) { + // @precondition: ensures bucket is not "used" in current batch + BK := &buckets[bucketID] + // handle special cases with inf or -P / P + if BK.IsInfinity() { + if isAdd { + BK.Set(PP) + } else { + BK.Neg(PP) + } + return + } + if BK.X.Equal(&PP.X) { + if BK.Y.Equal(&PP.Y) { + // P + P: doubling, which should be quite rare -- + if isAdd { + bucketsJE[bucketID].addMixed(PP) + } else { + BK.setInfinity() + } + return + } + if isAdd { + BK.setInfinity() + } else { + bucketsJE[bucketID].subMixed(PP) + } + return + } + + bucketIds[bucketID] = true + R[cptAdd] = BK + if isAdd { + P[cptAdd].Set(PP) + } else { + P[cptAdd].Neg(PP) + } + cptAdd++ + } + + flushQueue := func () { + for i:=0; i < qID; i++ { + bucketsJE[queue[i].bucketID].addMixed(&queue[i].point) + } + qID = 0 + } + + processTopQueue := func () { + for i := qID - 1; i >= 0; i-- { + if bucketIds[queue[i].bucketID] { + return + } + addFromQueue(queue[i]) + // len(queue) < batchSize so no need to check for full batch. + qID-- + } + } + + + for i, digit := range digits { + + if digit == 0 || points[i].IsInfinity() { + continue + } + + bucketID := uint16((digit>>1)) + isAdd := digit&1 == 0 + if isAdd { + // add + bucketID-=1 + } + + if bucketIds[bucketID] { + // put it in queue + queue[qID].bucketID = bucketID + if isAdd { + queue[qID].point.Set(&points[i]) + } else { + queue[qID].point.Neg(&points[i]) + } + qID++ + + // queue is full, flush it. + if qID == len(queue) - 1 { + flushQueue() + } + continue + } + + // we add the point to the batch. + add(bucketID, &points[i], isAdd) + if isFull() { + executeAndReset() + processTopQueue() + } + } + + + // flush items in batch. + executeAndReset() + + // empty the queue + flushQueue() + + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + var runningSum, total {{ $.TJacobianExtended }} + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + runningSum.addMixed(&buckets[k]) + if !bucketsJE[k].ZZ.IsZero() { + runningSum.add(&bucketsJE[k]) + } + total.add(&runningSum) + } + + chRes <- total + +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +{{- range $c := $.CRange}} +{{- if gt $c 9}} +type bucket{{ $.TAffine }}C{{$c}} [{{nbBuckets $c}}]{{ $.TAffine }} +{{- end}} +{{- end}} + + +// buckets: array of {{ $.TAffine }} points of size 1 << (c-1) +type ib{{ $.TAffine }} interface { + {{- range $i, $c := $.CRange}} + {{- if gt $c 9}} + bucket{{ $.TAffine }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} + {{- end}} +} + +// array of coordinates {{ $.CoordType }} +type c{{ $.TAffine }} interface { + {{- range $i, $c := $.CRange}} + {{- if gt $c 9}} + c{{ $.TAffine }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} + {{- end}} +} + +// buckets: array of {{ $.TAffine }} points (for the batch addition) +type p{{ $.TAffine }} interface { + {{- range $i, $c := $.CRange}} + {{- if gt $c 9}} + p{{ $.TAffine }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} + {{- end}} +} + +// buckets: array of *{{ $.TAffine }} points (for the batch addition) +type pp{{ $.TAffine }} interface { + {{- range $i, $c := $.CRange}} + {{- if gt $c 9}} + pp{{ $.TAffine }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} + {{- end}} +} + +// buckets: array of {{ $.TAffine }} queue operations (for the batch addition) +type qOps{{ $.TAffine }} interface { + {{- range $i, $c := $.CRange}} + {{- if gt $c 9}} + q{{ $.TAffine }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} + {{- end}} +} + + +{{- range $c := $.CRange}} +{{if gt $c 9}} +// batch size {{batchSize $c}} when c = {{$c}} +type c{{ $.TAffine }}C{{$c}} [{{batchSize $c}}]{{ $.CoordType }} +type p{{ $.TAffine }}C{{$c}} [{{batchSize $c}}]{{ $.TAffine }} +type pp{{ $.TAffine }}C{{$c}} [{{batchSize $c}}]*{{ $.TAffine }} +type q{{ $.TAffine }}C{{$c}} [{{batchSize $c}}]batchOp{{ $.TAffine }} +{{- end}} +{{- end}} + + +{{end }} + +{{- range $c := $.G1.CRange}} +type bitSetC{{$c}} [{{nbBuckets $c}}]bool +{{- end}} + +type bitSet interface { + {{- range $i, $c := $.G1.CRange}} + bitSetC{{$c}} {{- if not (last $i $.G1.CRange)}} | {{- end}} + {{- end}} +} diff --git a/internal/generator/ecc/template/multiexp_jacobian.go.tmpl b/internal/generator/ecc/template/multiexp_jacobian.go.tmpl new file mode 100644 index 000000000..166d185fa --- /dev/null +++ b/internal/generator/ecc/template/multiexp_jacobian.go.tmpl @@ -0,0 +1,76 @@ +{{ $G1TAffine := print (toUpper .G1.PointName) "Affine" }} +{{ $G1TJacobian := print (toUpper .G1.PointName) "Jac" }} +{{ $G1TJacobianExtended := print (toLower .G1.PointName) "JacExtended" }} + +{{ $G2TAffine := print (toUpper .G2.PointName) "Affine" }} +{{ $G2TJacobian := print (toUpper .G2.PointName) "Jac" }} +{{ $G2TJacobianExtended := print (toLower .G2.PointName) "JacExtended" }} + + + +{{ template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange }} +{{ template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange }} + + + +{{define "multiexp" }} + +func processChunk{{ $.UPointName }}Jacobian[B ib{{ $.TJacobianExtended }}](chunk uint64, + chRes chan<- {{ $.TJacobianExtended }}, + c uint64, + points []{{ $.TAffine }}, + digits []uint16) { + + + + var buckets B + for i := 0 ; i < len(buckets); i++ { + buckets[i].setInfinity() + } + + // for each scalars, get the digit corresponding to the chunk we're processing. + for i, digit := range digits { + if digit == 0 { + continue + } + + // if msbWindow bit is set, we need to substract + if digit & 1 == 0 { + // add + buckets[(digit>>1)-1].addMixed(&points[i]) + } else { + // sub + buckets[(digit>>1)].subMixed(&points[i]) + } + } + + + // reduce buckets into total + // total = bucket[0] + 2*bucket[1] + 3*bucket[2] ... + n*bucket[n-1] + + var runningSum, total {{ $.TJacobianExtended }} + runningSum.setInfinity() + total.setInfinity() + for k := len(buckets) - 1; k >= 0; k-- { + if !buckets[k].ZZ.IsZero() { + runningSum.add(&buckets[k]) + } + total.add(&runningSum) + } + + chRes <- total +} + +// we declare the buckets as fixed-size array types +// this allow us to allocate the buckets on the stack +{{- range $c := $.CRange}} +type bucket{{ $.TJacobianExtended }}C{{$c}} [{{nbBuckets $c}}]{{ $.TJacobianExtended }} +{{- end}} + +type ib{{ $.TJacobianExtended }} interface { + {{- range $i, $c := $.CRange}} + bucket{{ $.TJacobianExtended }}C{{$c}} {{- if not (last $i $.CRange)}} | {{- end}} + {{- end}} +} + +{{end }} diff --git a/internal/generator/ecc/template/point.go.tmpl b/internal/generator/ecc/template/point.go.tmpl index e54abcd56..5efbfeeef 100644 --- a/internal/generator/ecc/template/point.go.tmpl +++ b/internal/generator/ecc/template/point.go.tmpl @@ -9,7 +9,9 @@ import ( "math/big" "runtime" + {{- if .GLV}} "github.com/consensys/gnark-crypto/ecc" + {{- end}} "github.com/consensys/gnark-crypto/internal/parallel" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" {{- if or (eq .CoordType "fptower.E2") (eq .CoordType "fptower.E4") }} @@ -52,11 +54,22 @@ func (p *{{ $TAffine }}) Set(a *{{ $TAffine }}) *{{ $TAffine }} { return p } +// setInfinity sets p to O +func (p *{{ $TAffine }}) setInfinity() *{{ $TAffine }} { + p.X.SetZero() + p.Y.SetZero() + return p +} + // ScalarMultiplication computes and returns p = a ⋅ s func (p *{{ $TAffine }}) ScalarMultiplication(a *{{ $TAffine }}, s *big.Int) *{{ $TAffine }} { var _p {{ $TJacobian }} _p.FromAffine(a) - _p.mulGLV(&_p, s) + {{- if .GLV}} + _p.mulGLV(&_p, s) + {{- else }} + _p.mulWindowed(&_p, s) + {{- end }} p.FromJacobian(&_p) return p } @@ -66,7 +79,11 @@ func (p *{{ $TAffine }}) ScalarMultiplication(a *{{ $TAffine }}, s *big.Int) *{{ // Takes an affine point and returns a Jacobian point (useful for KZG) func (p *{{ $TJacobian }}) ScalarMultiplicationAffine(a *{{ $TAffine }}, s *big.Int) *{{ $TJacobian }} { p.FromAffine(a) - p.mulGLV(p, s) + {{- if .GLV}} + p.mulGLV(p, s) + {{- else }} + p.mulWindowed(p, s) + {{- end }} return p } {{- end}} @@ -409,19 +426,27 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { } {{else if eq .PointName "g2"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. - // [r]P == 0 <==> Frob(P) == [6x²]P + // https://eprint.iacr.org/2022/348.pdf, sec. 3 and 5.1 + // [r]P == 0 <==> [x₀+1]P + ψ([x₀]P) + ψ²([x₀]P) = ψ³([2x₀]P) func (p *{{ $TJacobian }}) IsInSubGroup() bool { - var a, res G2Jac - a.psi(p) - res.ScalarMultiplication(p, &fixedCoeff). - SubAssign(&a) - - return res.IsOnCurve() && res.Z.IsZero() + var a, b, c, res G2Jac + a.ScalarMultiplication(p, &xGen) + b.psi(&a) + a.AddAssign(p) + res.psi(&b) + c.Set(&res). + AddAssign(&b). + AddAssign(&a) + res.psi(&res). + Double(&res). + SubAssign(&c) + return res.IsOnCurve() && res.Z.IsZero() } {{- end}} {{else if or (eq .Name "bw6-761") (eq .Name "bw6-756")}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + {{ if .GLV}} // Z[r,0]+Z[-lambda{{ $TAffine }}, 1] is the kernel // of (u,v)->u+lambda{{ $TAffine }}v mod r. Expressing r, lambda{{ $TAffine }} as // polynomials in x, a short vector of this Zmodule is @@ -440,10 +465,18 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { phip.ScalarMultiplication(p, &xGen).AddAssign(p).AddAssign(&res) return phip.IsOnCurve() && phip.Z.IsZero() + {{ else}} + func (p *{{ $TJacobian }}) IsInSubGroup() bool { + + var res {{ $TJacobian }} + res.ScalarMultiplication(p, fr.Modulus()) + return res.IsOnCurve() && res.Z.IsZero() + {{ end}} } {{else if eq .Name "bw6-633"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + {{ if .GLV}} // 3r P = (x+1)ϕ(P) + (-x^5 + x⁴ + x)P func (p *{{ $TJacobian }}) IsInSubGroup() bool { @@ -458,11 +491,18 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { AddAssign(&u4P). AddAssign(&u5P) + {{ else}} + func (p *{{ $TJacobian }}) IsInSubGroup() bool { + + var r {{ $TJacobian }} + r.ScalarMultiplication(p, fr.Modulus()) + {{ end}} return r.IsOnCurve() && r.Z.IsZero() } {{else if or (eq .Name "bls24-315") (eq .Name "bls24-317")}} {{- if eq .PointName "g1"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. + {{ if .GLV}} // Z[r,0]+Z[-lambda{{ $TAffine }}, 1] is the kernel // of (u,v)->u+lambda{{ $TAffine }}v mod r. Expressing r, lambda{{ $TAffine }} as // polynomials in x, a short vector of this Zmodule is @@ -477,6 +517,13 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { ScalarMultiplication(&res, &xGen). ScalarMultiplication(&res, &xGen). AddAssign(p) + {{ else}} + func (p *{{ $TJacobian }}) IsInSubGroup() bool { + + var res {{ $TJacobian }} + res.ScalarMultiplication(p, fr.Modulus()) + return res.IsOnCurve() && res.Z.IsZero() + {{ end}} return res.IsOnCurve() && res.Z.IsZero() @@ -484,7 +531,8 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { {{else if eq .PointName "g2"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // https://eprint.iacr.org/2021/1130.pdf, sec.4 - // ψ(p) = x₀ P + // and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 + // ψ(p) = [x₀]P func (p *{{ $TJacobian }}) IsInSubGroup() bool { var res, tmp {{ $TJacobian }} tmp.psi(p) @@ -510,10 +558,14 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { func (p *{{ $TJacobian }}) IsInSubGroup() bool { var res {{ $TJacobian }} + {{ if .GLV}} res.phi(p). ScalarMultiplication(&res, &xGen). ScalarMultiplication(&res, &xGen). AddAssign(p) + {{ else}} + res.ScalarMultiplication(p, fr.Modulus()) + {{ end}} return res.IsOnCurve() && res.Z.IsZero() @@ -522,7 +574,8 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { {{if eq .Name "bls12-381"}} // IsInSubGroup returns true if p is on the r-torsion, false otherwise. // https://eprint.iacr.org/2021/1130.pdf, sec.4 - // ψ(p) = x₀ P + // and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 + // ψ(p) = [x₀]P func (p *{{ $TJacobian }}) IsInSubGroup() bool { var res, tmp {{ $TJacobian }} tmp.psi(p) @@ -533,7 +586,8 @@ func (p *{{ $TJacobian }}) IsOnCurve() bool { } {{else if or (eq .Name "bls12-377") (eq .Name "bls12-378")}} // https://eprint.iacr.org/2021/1130.pdf, sec.4 - // ψ(p) = x₀ P + // and https://eprint.iacr.org/2022/352.pdf, sec. 4.2 + // ψ(p) = [x₀]P func (p *{{ $TJacobian }}) IsInSubGroup() bool { var res, tmp {{ $TJacobian }} tmp.psi(p) @@ -655,8 +709,8 @@ func (p *{{ $TJacobian }}) mulGLV(a *{{ $TJacobian }}, s *big.Int) *{{ $TJacobia // bounds on the lattice base vectors guarantee that k1, k2 are len(r)/2 or len(r)/2+1 bits long max // this is because we use a probabilistic scalar decomposition that replaces a division by a right-shift - k1.SetBigInt(&k[0]).FromMont() - k2.SetBigInt(&k[1]).FromMont() + k1 = k1.SetBigInt(&k[0]).Bits() + k2 = k2.SetBigInt(&k[1]).Bits() // we don't target constant-timeness so we check first if we increase the bounds or not maxBit := k1.BitLen() @@ -715,6 +769,7 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { p.Set(&res) return p {{else if eq .Name "bw6-761"}} +{{ if .GLV}} // https://eprint.iacr.org/2020/351.pdf var points [4]{{$TJacobian}} points[0].Set(a) @@ -749,9 +804,15 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { p2.phi(&p2) p.Set(&p1).AddAssign(&p2) +{{ else}} + var c1 big.Int + c1.SetString("26642435879335816683987677701488073867751118270052650655942102502312977592501693353047140953112195348280268661194876", 10) + p.ScalarMultiplication(a, &c1) +{{ end}} return p {{else if eq .Name "bw6-633"}} +{{ if .GLV}} var uP, vP, wP, L0, L1, tmp {{$TJacobian}} var v, one, uPlusOne, uMinusOne, d1, d2, ht big.Int one.SetInt64(1) @@ -779,9 +840,15 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { L1.AddAssign(&tmp) p.phi(&L1).AddAssign(&L0) +{{ else}} + var c1 big.Int + c1.SetString("516166855112631370346774477030598579858367278343565509012644853411927535599366632765988905418773", 10) + p.ScalarMultiplication(a, &c1) +{{ end}} return p {{else if eq .Name "bw6-756"}} +{{ if .GLV}} var L0, L1, uP, u2P, u3P, tmp G1Jac uP.ScalarMultiplication(a, &xGen) @@ -806,6 +873,11 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { p.phi(&L1). AddAssign(&L0) +{{ else}} + var c1 big.Int + c1.SetString("605248206075306171568857128027361794400937215108643640003009340657451546212610770151705515081537938829431808196608", 10) + p.ScalarMultiplication(a, &c1) +{{ end}} return p {{- end}} @@ -933,6 +1005,7 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { return p {{else if eq .Name "bw6-761"}} +{{- if .GLV}} var points [4]{{$TJacobian}} points[0].Set(a) points[1].ScalarMultiplication(a, &xGen) @@ -966,9 +1039,15 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { p2.phi(&p2).phi(&p2) p.Set(&p1).AddAssign(&p2) +{{else}} + var c2 big.Int + c2.SetString("26642435879335816683987677701488073867751118270052650655942102502312977592501693353047140953112195348280268661194869", 10) + p.ScalarMultiplication(a, &c2) +{{end}} return p {{else if eq .Name "bw6-633"}} +{{- if .GLV}} var uP, u2P, u3P, u4P, u5P, xP, vP, wP, L0, L1, tmp {{$TJacobian}} var ht, d1, d3 big.Int ht.SetInt64(7) // negative @@ -999,10 +1078,16 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { L1.AddAssign(&tmp) p.phi(&L1).AddAssign(&L0) +{{else}} + var c2 big.Int + c2.SetString("516166855112631370346774477030598579858367278343565509012644853411927535599366632765988905418768", 10) + p.ScalarMultiplication(a, &c2) +{{end}} return p {{else if eq .Name "bw6-756"}} +{{- if .GLV}} var L0, L1, uP, u2P, u3P, tmp G2Jac uP.ScalarMultiplication(a, &xGen) @@ -1023,6 +1108,11 @@ func (p *{{$TJacobian}}) ClearCofactor(a *{{$TJacobian}}) *{{$TJacobian}} { p.phi(&L0). AddAssign(&L1) +{{else}} + var c2 big.Int + c2.SetString("605248206075306171568857128027361794400937215108643640003009340657451546212610770151705515081537938829431808196609", 10) + p.ScalarMultiplication(a, &c2) +{{end}} return p @@ -1099,17 +1189,17 @@ func (p *{{ $TJacobianExtended }}) add(q *{{ $TJacobianExtended }}) *{{ $TJacobi return p } - var A, B, X1ZZ2, X2ZZ1, Y1ZZZ2, Y2ZZZ1 {{.CoordType}} + var A, B, U1, U2, S1, S2 {{.CoordType}} // p2: q, p1: p - X2ZZ1.Mul(&q.X, &p.ZZ) - X1ZZ2.Mul(&p.X, &q.ZZ) - A.Sub(&X2ZZ1, &X1ZZ2) - Y2ZZZ1.Mul(&q.Y, &p.ZZZ) - Y1ZZZ2.Mul(&p.Y, &q.ZZZ) - B.Sub(&Y2ZZZ1, &Y1ZZZ2) - - if A.IsZero() { + U2.Mul(&q.X, &p.ZZ) + U1.Mul(&p.X, &q.ZZ) + A.Sub(&U2, &U1) + S2.Mul(&q.Y, &p.ZZZ) + S1.Mul(&p.Y, &q.ZZZ) + B.Sub(&S2, &S1) + + if A.IsZero() { if B.IsZero() { return p.double(q) @@ -1119,11 +1209,8 @@ func (p *{{ $TJacobianExtended }}) add(q *{{ $TJacobianExtended }}) *{{ $TJacobi return p } - var U1, U2, S1, S2, P, R, PP, PPP, Q, V {{.CoordType}} - U1.Mul(&p.X, &q.ZZ) - U2.Mul(&q.X, &p.ZZ) - S1.Mul(&p.Y, &q.ZZZ) - S2.Mul(&q.Y, &p.ZZZ) + + var P, R, PP, PPP, Q, V {{.CoordType}} P.Sub(&U2, &U1) R.Sub(&S2, &S1) PP.Square(&P) @@ -1148,6 +1235,8 @@ func (p *{{ $TJacobianExtended }}) add(q *{{ $TJacobianExtended }}) *{{ $TJacobi // double point in Jacobian extended coordinates // http://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html#doubling-dbl-2008-s-1 +// since we consider any point on Z=0 as the point at infinity +// this doubling formula works for infinity points as well func (p *{{ $TJacobianExtended }}) double(q *{{ $TJacobianExtended }}) *{{ $TJacobianExtended }} { var U, V, W, S, XX, M {{.CoordType}} @@ -1428,63 +1517,41 @@ func BatchJacobianToAffine{{ toUpper .PointName }}(points []{{ $TJacobian }}) [] // and return resulting points in affine coordinates // uses a simple windowed-NAF like exponentiation algorithm func BatchScalarMultiplication{{ toUpper .PointName }}(base *{{ $TAffine }}, scalars []fr.Element) []{{ $TAffine }} { - // approximate cost in group ops is // cost = 2^{c-1} + n(scalar.nbBits+nbChunks) nbPoints := uint64(len(scalars)) min := ^uint64(0) bestC := 0 - for c := 2; c < 18; c++ { - cost := uint64(1 << (c-1)) - nbChunks := uint64(fr.Limbs * 64 / c) - if (fr.Limbs*64) %c != 0 { - nbChunks++ - } - cost += nbPoints*((fr.Limbs*64) + nbChunks) + for c := 2; c <= 16; c++ { + cost := uint64(1 << (c-1)) // pre compute the table + nbChunks := computeNbChunks(uint64(c)) + cost += nbPoints * (uint64(c) + 1) * nbChunks // doublings + point add if cost < min { min = cost bestC = c } } - c := uint64(bestC) // window size - nbChunks := int(fr.Limbs * 64 / c) - if (fr.Limbs*64) %c != 0 { - nbChunks++ + c := uint64(bestC) // window size + nbChunks := int(computeNbChunks(c)) + + // last window may be slightly larger than c; in which case we need to compute one + // extra element in the baseTable + maxC := lastC(c) + if c > maxC { + maxC = c } - mask := uint64((1 << c) - 1) // low c bits are 1 - msbWindow := uint64(1 << (c -1)) // precompute all powers of base for our window // note here that if performance is critical, we can implement as in the msmX methods // this allocation to be on the stack - baseTable := make([]{{ $TJacobian }}, (1<<(c-1))) - baseTable[0].Set(&{{ toLower .PointName}}Infinity) - baseTable[0].AddMixed(base) + baseTable := make([]{{ $TJacobian }}, (1<<(maxC-1))) + baseTable[0].FromAffine(base) for i:=1;i (64-c) && d.index < (fr.Limbs - 1 ) - if d.multiWordSelect { - nbBitsHigh := d.shift - uint64(64-c) - d.maskHigh = (1 << nbBitsHigh) - 1 - d.shiftHigh = (c - nbBitsHigh) - } - selectors[chunk] = d - } - {{- if eq .PointName "g1"}} // convert our base exp table into affine to use AddMixed baseTableAff := BatchJacobianToAffine{{ toUpper .PointName}}(baseTable) @@ -1493,44 +1560,43 @@ func BatchScalarMultiplication{{ toUpper .PointName }}(base *{{ $TAffine }}, sca toReturn := make([]{{ $TAffine }}, len(scalars)) {{- end}} + // partition the scalars into digits + digits, _ := partitionScalars(scalars, c, runtime.NumCPU()) + // for each digit, take value in the base table, double it c time, voilà. - parallel.Execute( len(pScalars), func(start, end int) { + parallel.Execute( len(scalars), func(start, end int) { var p {{ $TJacobian }} for i:=start; i < end; i++ { p.Set(&{{ toLower .PointName}}Infinity) for chunk := nbChunks - 1; chunk >=0; chunk-- { - s := selectors[chunk] if chunk != nbChunks -1 { for j:=uint64(0); j> s.shift - if s.multiWordSelect { - bits += (pScalars[i][s.index+1] & s.maskHigh) << s.shiftHigh - } - - if bits == 0 { + if digit == 0 { continue } // if msbWindow bit is set, we need to substract - if bits & msbWindow == 0 { + if digit & 1 == 0 { // add {{- if eq .PointName "g1"}} - p.AddMixed(&baseTableAff[bits-1]) + p.AddMixed(&baseTableAff[(digit >> 1)-1]) {{- else}} - p.AddAssign(&baseTable[bits-1]) + p.AddAssign(&baseTable[(digit >> 1)-1]) {{- end}} } else { // sub {{- if eq .PointName "g1"}} - t := baseTableAff[bits & ^msbWindow] + t := baseTableAff[digit >> 1] t.Neg(&t) p.AddMixed(&t) {{- else}} - t := baseTable[bits & ^msbWindow] + t := baseTable[digit >> 1] t.Neg(&t) p.AddAssign(&t) {{- end}} @@ -1554,3 +1620,57 @@ func BatchScalarMultiplication{{ toUpper .PointName }}(base *{{ $TAffine }}, sca return toReturn {{- end}} } + + + +// batch add affine coordinates +// using batch inversion +// special cases (doubling, infinity) must be filtered out before this call +func batchAdd{{ $TAffine }}[TP p{{ $TAffine }}, TPP pp{{ $TAffine }}, TC c{{ $TAffine }}](R *TPP,P *TP, batchSize int) { + var lambda, lambdain TC + + + // add part + for j := 0; j < batchSize; j++ { + lambdain[j].Sub(&(*P)[j].X, &(*R)[j].X) + } + + // invert denominator using montgomery batch invert technique + { + var accumulator {{.CoordType}} + lambda[0].SetOne() + accumulator.Set(&lambdain[0]) + + for i := 1; i < batchSize; i++ { + lambda[i] = accumulator + accumulator.Mul(&accumulator, &lambdain[i]) + } + + accumulator.Inverse(&accumulator) + + for i := batchSize - 1; i > 0; i-- { + lambda[i].Mul(&lambda[i], &accumulator) + accumulator.Mul(&accumulator, &lambdain[i]) + } + lambda[0].Set(&accumulator) + } + + var d {{.CoordType}} + var rr {{ $TAffine }} + + // add part + for j := 0; j < batchSize; j++ { + // computa lambda + d.Sub(&(*P)[j].Y, &(*R)[j].Y) + lambda[j].Mul(&lambda[j], &d) + + // compute X, Y + rr.X.Square(&lambda[j]) + rr.X.Sub(&rr.X, &(*R)[j].X) + rr.X.Sub(&rr.X, &(*P)[j].X) + d.Sub(&(*R)[j].X, &rr.X) + rr.Y.Mul(&lambda[j], &d) + rr.Y.Sub(&rr.Y, &(*R)[j].Y) + (*R)[j].Set(&rr) + } +} diff --git a/internal/generator/ecc/template/tests/hash_to_curve.go.tmpl b/internal/generator/ecc/template/tests/hash_to_curve.go.tmpl index 7739ea75f..053eaf753 100644 --- a/internal/generator/ecc/template/tests/hash_to_curve.go.tmpl +++ b/internal/generator/ecc/template/tests/hash_to_curve.go.tmpl @@ -68,7 +68,7 @@ func Test{{$CurveTitle}}SqrtRatio(t *testing.T) { func TestHashToFp{{$CurveTitle}}(t *testing.T) { for _, c := range encodeTo{{$CurveTitle}}Vector.cases { - elems, err := hashToFp([]byte(c.msg), encodeTo{{$CurveTitle}}Vector.dst, {{$TowerDegree}}) + elems, err := fp.Hash([]byte(c.msg), encodeTo{{$CurveTitle}}Vector.dst, {{$TowerDegree}}) if err != nil { t.Error(err) } @@ -76,7 +76,7 @@ func TestHashToFp{{$CurveTitle}}(t *testing.T) { } for _, c := range hashTo{{$CurveTitle}}Vector.cases { - elems, err := hashToFp([]byte(c.msg), hashTo{{$CurveTitle}}Vector.dst, 2 * {{$TowerDegree}}) + elems, err := fp.Hash([]byte(c.msg), hashTo{{$CurveTitle}}Vector.dst, 2 * {{$TowerDegree}}) if err != nil { t.Error(err) } diff --git a/internal/generator/ecc/template/tests/multiexp.go.tmpl b/internal/generator/ecc/template/tests/multiexp.go.tmpl index dd8751733..3be1ec4b4 100644 --- a/internal/generator/ecc/template/tests/multiexp.go.tmpl +++ b/internal/generator/ecc/template/tests/multiexp.go.tmpl @@ -9,12 +9,14 @@ import ( "fmt" + "time" + "runtime" + "math/rand" "math/big" "testing" - "runtime" "math/bits" "sync" - + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" "github.com/leanovate/gopter" @@ -22,26 +24,26 @@ import ( ) -{{template "multiexp" dict "PointName" .G1.PointName "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} -{{template "multiexp" dict "PointName" .G2.PointName "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} +{{template "multiexp" dict "PointName" .G1.PointName "UPointName" (toUpper .G1.PointName) "TAffine" $G1TAffine "TJacobian" $G1TJacobian "TJacobianExtended" $G1TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G1.CRange}} +{{template "multiexp" dict "PointName" .G2.PointName "UPointName" (toUpper .G2.PointName) "TAffine" $G2TAffine "TJacobian" $G2TJacobian "TJacobianExtended" $G2TJacobianExtended "FrNbWords" .Fr.NbWords "CRange" .G2.CRange}} {{define "multiexp" }} -func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { +func TestMultiExp{{$.UPointName}}(t *testing.T) { parameters := gopter.DefaultTestParameters() if testing.Short() { - parameters.MinSuccessfulTests = 2 + parameters.MinSuccessfulTests = 3 } else { - parameters.MinSuccessfulTests = nbFuzzShort + parameters.MinSuccessfulTests = nbFuzzShort * 2 } properties := gopter.NewProperties(parameters) genScalar := GenFr() - - // size of the multiExps + + // size of the multiExps const nbSamples = 73 // multi exp points @@ -53,6 +55,13 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { g.AddAssign(&{{ toLower $.PointName }}Gen) } + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + // final scalar to use in double and add method (without mixer factor) // n(n+1)(2n+1)/6 (sum of the squares from 1 to n) var scalar big.Int @@ -63,7 +72,7 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { // ensure a multiexp that's splitted has the same result as a non-splitted one.. - properties.Property("[{{ toUpper $.PointName }}] Multi exponentation (c=16) should be consistent with splitted multiexp", prop.ForAll( + properties.Property("[{{ $.UPointName }}] Multi exponentation (c=16) should be consistent with splitted multiexp", prop.ForAll( func(mixer fr.Element) bool { var samplePointsLarge [nbSamples*13]{{ $.TAffine }} for i:=0; i<13; i++ { @@ -71,19 +80,16 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { } var r16, splitted1, splitted2 {{ $.TJacobian }} - + // mixer ensures that all the words of a fpElement are set var sampleScalars [nbSamples*13]fr.Element - + for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - scalars16, _ := partitionScalars(sampleScalars[:], 16, false, runtime.NumCPU()) - r16.msmC16(samplePoints[:], scalars16, true) - + r16.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{}) splitted1.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 128}) splitted2.MultiExp(samplePointsLarge[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: 51}) return r16.Equal(&splitted1) && r16.Equal(&splitted2) @@ -94,7 +100,7 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { // cRange is generated from template and contains the available parameters for the multiexp window size {{- if eq $.PointName "g1" }} cRange := []uint64{ - {{- range $c := $.CRange}} {{- if and (eq $.PointName "g1") (gt $c 21)}}{{- else}} {{$c}},{{- end}}{{- end}} + {{- range $c := $.CRange}}{{- if gt $c 1}}{{$c}},{{- end}}{{- end}} } if testing.Short() { // test only "odd" and "even" (ie windows size divide word size vs not) @@ -106,11 +112,11 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { cRange := []uint64{5, 16} {{- end}} - properties.Property(fmt.Sprintf("[{{ toUpper $.PointName }}] Multi exponentation (c in %v) should be consistent with sum of square", cRange), prop.ForAll( + properties.Property(fmt.Sprintf("[{{ $.UPointName }}] Multi exponentation (c in %v) should be consistent with sum of square", cRange), prop.ForAll( func(mixer fr.Element) bool { - + var expected {{ $.TJacobian }} - + // compute expected result with double and add var finalScalar,mixerBigInt big.Int finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) @@ -118,29 +124,81 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { // mixer ensures that all the words of a fpElement are set var sampleScalars [nbSamples]fr.Element - + for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } - - results := make([]{{ $.TJacobian }}, len(cRange) + 1) + + results := make([]{{ $.TJacobian }}, len(cRange)) for i, c := range cRange { - scalars, _ := partitionScalars(sampleScalars[:], c, false, runtime.NumCPU()) - msmInner{{ $.TJacobian }}(&results[i], int(c), samplePoints[:], scalars, false) - if c == 16 { - // split the first chunk - msmInner{{ $.TJacobian }}(&results[len(results)-1], 16, samplePoints[:], scalars, true) - } + _innerMsm{{ $.UPointName }}(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) } for i:=1; i < len(results);i++ { if !results[i].Equal(&results[i-1]) { + t.Logf("result for c=%d != c=%d", cRange[i-1],cRange[i]) + return false + } + } + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[{{ $.UPointName }}] Multi exponentation (c in %v) of points at infinity should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + + var samplePointsZero [nbSamples]{{ $.TAffine }} + + var expected {{ $.TJacobian }} + + // compute expected result with double and add + var finalScalar, mixerBigInt big.Int + finalScalar.Mul(&scalar, mixer.ToBigIntRegular(&mixerBigInt)) + expected.ScalarMultiplication(&{{ toLower $.PointName }}Gen, &finalScalar) + + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + for i := 1; i <= nbSamples; i++ { + sampleScalars[i-1].SetUint64(uint64(i)). + Mul(&sampleScalars[i-1], &mixer) + samplePointsZero[i-1].setInfinity() + } + + results := make([]{{ $.TJacobian }}, len(cRange)) + for i, c := range cRange { + _innerMsm{{ $.UPointName }}(&results[i], c, samplePointsZero[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) return false } } - return true + return true + }, + genScalar, + )) + + properties.Property(fmt.Sprintf("[{{ $.UPointName }}] Multi exponentation (c in %v) with a vector of 0s as input should output a point at infinity", cRange), prop.ForAll( + func(mixer fr.Element) bool { + // mixer ensures that all the words of a fpElement are set + var sampleScalars [nbSamples]fr.Element + + + results := make([]{{ $.TJacobian }}, len(cRange)) + for i, c := range cRange { + _innerMsm{{ $.UPointName }}(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + for i := 0; i < len(results); i++ { + if !results[i].Z.IsZero() { + t.Logf("result for c=%d is not infinity", cRange[i]) + return false + } + } + return true }, genScalar, )) @@ -148,7 +206,7 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { // note : this test is here as we expect to have a different multiExp than the above bucket method // for small number of points - properties.Property("[{{ toUpper $.PointName }}] Multi exponentation (<50points) should be consistent with sum of square", prop.ForAll( + properties.Property("[{{ $.UPointName }}] Multi exponentation (<50points) should be consistent with sum of square", prop.ForAll( func(mixer fr.Element) bool { var g {{ $.TJacobian }} @@ -160,8 +218,7 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { for i := 1; i <= 30; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) samplePoints[i-1].FromJacobian(&g) g.AddAssign(&{{ toLower .PointName}}Gen) } @@ -186,7 +243,100 @@ func TestMultiExp{{toUpper $.PointName}}(t *testing.T) { } -func BenchmarkMultiExp{{ toUpper $.PointName }}(b *testing.B) { +func TestCrossMultiExp{{ $.UPointName }}(t *testing.T) { + const nbSamples = 1 << 14 + // multi exp points + var samplePoints [nbSamples]{{ $.TAffine }} + var g {{ $.TJacobian }} + g.Set(&{{ toLower $.PointName }}Gen) + for i := 1; i <= nbSamples; i++ { + samplePoints[i-1].FromJacobian(&g) + g.AddAssign(&{{ toLower $.PointName }}Gen) + } + + // sprinkle some points at infinity + rand.Seed(time.Now().UnixNano()) + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + samplePoints[rand.Intn(nbSamples)].setInfinity() + + + var sampleScalars [nbSamples]fr.Element + fillBenchScalars(sampleScalars[:]) + + // sprinkle some doublings + for i:=10; i < 100; i++ { + samplePoints[i] = samplePoints[0] + sampleScalars[i] = sampleScalars[0] + } + + // cRange is generated from template and contains the available parameters for the multiexp window size + {{- if eq $.PointName "g1" }} + cRange := []uint64{ + {{- range $c := $.CRange}}{{- if gt $c 1}}{{$c}},{{- end}}{{- end}} + } + if testing.Short() { + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange = []uint64{5, 16} + } + {{- else }} + // for g2, CI suffers with large c size since it needs to allocate a lot of memory for the buckets. + // test only "odd" and "even" (ie windows size divide word size vs not) + cRange := []uint64{5, 16} + {{- end}} + + results := make([]{{ $.TJacobian }}, len(cRange)) + for i, c := range cRange { + _innerMsm{{ $.UPointName }}(&results[i], c, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + } + + var r {{ $.TJacobian }} + _innerMsm{{ $.UPointName }}Reference(&r, samplePoints[:], sampleScalars[:], ecc.MultiExpConfig{NbTasks: runtime.NumCPU()}) + + var expected, got {{ $.TAffine}} + expected.FromJacobian(&r) + + for i:=0; i= 0; j-- { + processChunk := processChunk{{ $.UPointName }}Jacobian[bucket{{ $.TJacobianExtended }}C16] + go processChunk(uint64(j), chChunks[j], 16, points, digits[j*n:(j+1)*n]) + } + + return msmReduceChunk{{ $.TAffine }}(p, int(16), chChunks[:]) +} + + +func BenchmarkMultiExp{{ $.UPointName }}(b *testing.B) { const ( pow = (bits.UintSize / 2) - (bits.UintSize / 8) // 24 on 64 bits arch, 12 on 32 bits @@ -196,10 +346,33 @@ func BenchmarkMultiExp{{ toUpper $.PointName }}(b *testing.B) { var ( samplePoints [nbSamples]{{ $.TAffine }} sampleScalars [nbSamples]fr.Element + sampleScalarsSmallValues [nbSamples]fr.Element + sampleScalarsRedundant [nbSamples]fr.Element ) fillBenchScalars(sampleScalars[:]) - fillBenchBases{{ toUpper $.PointName }}(samplePoints[:]) + copy(sampleScalarsSmallValues[:],sampleScalars[:]) + copy(sampleScalarsRedundant[:],sampleScalars[:]) + + // this means first chunk is going to have more work to do and should be split into several go routines + for i:=0; i < len(sampleScalarsSmallValues);i++ { + if i % 5 == 0 { + sampleScalarsSmallValues[i].SetZero() + sampleScalarsSmallValues[i][0] = 1 + } + } + + // bad case for batch affine because scalar distribution might look uniform + // but over batchSize windows, we may hit a lot of conflicts and force the msm-affine + // to process small batches of additions to flush its queue of conflicted points. + for i:=0; i < len(sampleScalarsRedundant);i+=100 { + for j:=i+1; j < i+100 && j < len(sampleScalarsRedundant);j++ { + sampleScalarsRedundant[j] = sampleScalarsRedundant[i] + } + } + + fillBenchBases{{ $.UPointName }}(samplePoints[:]) + var testPoint {{ $.TAffine }} @@ -212,11 +385,25 @@ func BenchmarkMultiExp{{ toUpper $.PointName }}(b *testing.B) { testPoint.MultiExp(samplePoints[:using], sampleScalars[:using],ecc.MultiExpConfig{}) } }) + + b.Run(fmt.Sprintf("%d points-smallvalues", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsSmallValues[:using],ecc.MultiExpConfig{}) + } + }) + + b.Run(fmt.Sprintf("%d points-redundancy", using), func(b *testing.B) { + b.ResetTimer() + for j := 0; j < b.N; j++ { + testPoint.MultiExp(samplePoints[:using], sampleScalarsRedundant[:using],ecc.MultiExpConfig{}) + } + }) } } -func BenchmarkMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { +func BenchmarkMultiExp{{ $.UPointName }}Reference(b *testing.B) { const nbSamples = 1 << 20 var ( @@ -225,7 +412,7 @@ func BenchmarkMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { ) fillBenchScalars(sampleScalars[:]) - fillBenchBases{{ toUpper $.PointName }}(samplePoints[:]) + fillBenchBases{{ $.UPointName }}(samplePoints[:]) var testPoint {{ $.TAffine }} @@ -236,7 +423,7 @@ func BenchmarkMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { } -func BenchmarkManyMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { +func BenchmarkManyMultiExp{{ $.UPointName }}Reference(b *testing.B) { const nbSamples = 1 << 20 var ( @@ -245,7 +432,7 @@ func BenchmarkManyMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { ) fillBenchScalars(sampleScalars[:]) - fillBenchBases{{ toUpper $.PointName }}(samplePoints[:]) + fillBenchBases{{ $.UPointName }}(samplePoints[:]) var t1, t2, t3 {{ $.TAffine }} @@ -274,30 +461,28 @@ func BenchmarkManyMultiExp{{ toUpper $.PointName }}Reference(b *testing.B) { // // Rationale for generating points that are not on the curve is that for large benchmarks, generating // a vector of different points can take minutes. Using the same point or subset will bias the benchmark result -// since bucket additions in extended jacobian coordinates will hit doubling algorithm instead of add. -func fillBenchBases{{ toUpper $.PointName }}(samplePoints []{{ $.TAffine }}) { +// since bucket additions in extended jacobian coordinates will hit doubling algorithm instead of add. +func fillBenchBases{{ $.UPointName }}(samplePoints []{{ $.TAffine }}) { var r big.Int r.SetString("340444420969191673093399857471996460938405", 10) samplePoints[0].ScalarMultiplication(&samplePoints[0], &r) one := samplePoints[0].X one.SetOne() - + for i := 1; i < len(samplePoints); i++ { samplePoints[i].X.Add(&samplePoints[i-1].X, &one) - samplePoints[i].Y.Sub(&samplePoints[i-1].Y, &one) + samplePoints[i].Y.Sub(&samplePoints[i-1].Y, &one) } } + {{end }} + func fillBenchScalars(sampleScalars []fr.Element) { // ensure every words of the scalars are filled - var mixer fr.Element - mixer.SetString("7716837800905789770901243404444209691916730933998574719964609384059111546487") - for i := 1; i <= len(sampleScalars); i++ { - sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + for i := 0; i < len(sampleScalars); i++ { + sampleScalars[i].SetRandom() } -} \ No newline at end of file +} diff --git a/internal/generator/ecc/template/tests/point.go.tmpl b/internal/generator/ecc/template/tests/point.go.tmpl index 556d9befc..161177d0f 100644 --- a/internal/generator/ecc/template/tests/point.go.tmpl +++ b/internal/generator/ecc/template/tests/point.go.tmpl @@ -16,6 +16,7 @@ import ( "fmt" "math/big" "testing" + "math/rand" {{if or (eq .CoordType "fptower.E2") (eq .CoordType "fptower.E4")}} "github.com/consensys/gnark-crypto/ecc/{{.Name}}/internal/fptower" @@ -370,7 +371,7 @@ func Test{{ $TAffine }}Ops(t *testing.T) { r := fr.Modulus() var g {{ $TJacobian }} - g.mulGLV(&{{.PointName}}Gen, r) + g.ScalarMultiplication(&{{.PointName}}Gen, r) var scalar, blindedScalar, rminusone big.Int var op1, op2, op3, gneg {{ $TJacobian }} @@ -518,8 +519,7 @@ func Test{{ $TAffine }}BatchScalarMultiplication(t *testing.T) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } result := BatchScalarMultiplication{{ toUpper .PointName }}(&{{.PointName}}GenAff, sampleScalars[:]) @@ -532,7 +532,7 @@ func Test{{ $TAffine }}BatchScalarMultiplication(t *testing.T) { var expectedJac {{ $TJacobian }} var expected {{ $TAffine }} var b big.Int - expectedJac.mulGLV(&{{.PointName}}Gen, sampleScalars[i].ToBigInt(&b)) + expectedJac.ScalarMultiplication(&{{.PointName}}Gen, sampleScalars[i].ToBigIntRegular(&b)) expected.FromJacobian(&expectedJac) if !result[i].Equal(&expected) { return false @@ -559,6 +559,33 @@ func Benchmark{{ $TJacobian }}IsInSubGroup(b *testing.B) { } +func BenchmarkBatchAdd{{ $TAffine }}(b *testing.B) { + {{$c := 16}} + var P, R p{{$TAffine}}C{{$c}} + var RR pp{{$TAffine}}C{{$c}} + ridx := make([]int, len(P)) + + // TODO P == R may produce skewed benches + fillBenchBases{{ toUpper $.PointName }}(P[:]) + fillBenchBases{{ toUpper $.PointName }}(R[:]) + + for i:=0; i < len(ridx);i++ { + ridx[i] = i + } + + // random permute + rand.Shuffle(len(ridx), func(i, j int) { ridx[i], ridx[j] = ridx[j], ridx[i] }) + + for i, ri := range ridx { + RR[i] = &R[ri] + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + batchAdd{{ $TAffine }}[p{{$TAffine}}C{{$c}}, pp{{$TAffine}}C{{$c}}, c{{$TAffine}}C{{$c}}](&RR, &P, len(P)) + } +} + func Benchmark{{ $TAffine }}BatchScalarMultiplication(b *testing.B) { // ensure every words of the scalars are filled var mixer fr.Element @@ -571,8 +598,7 @@ func Benchmark{{ $TAffine }}BatchScalarMultiplication(b *testing.B) { for i := 1; i <= nbSamples; i++ { sampleScalars[i-1].SetUint64(uint64(i)). - Mul(&sampleScalars[i-1], &mixer). - FromMont() + Mul(&sampleScalars[i-1], &mixer) } for i := 5; i <= pow; i++ { diff --git a/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl b/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl index 09a2b32e7..0f0f04eff 100644 --- a/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl +++ b/internal/generator/edwards/eddsa/template/eddsa.test.go.tmpl @@ -20,8 +20,10 @@ func Example() { privateKey, _ := GenerateKey(crand.Reader) publicKey := privateKey.PublicKey - // note that the message is on 4 bytes - msg := []byte{0xde, 0xad, 0xf0, 0x0d} + // generate a message (the size must be a multiple of the size of Fr) + var _msg fr.Element + _msg.SetRandom() + msg := _msg.Marshal() // sign the message signature, _ := privateKey.Sign(msg, hFunc) diff --git a/internal/generator/edwards/template/point.go.tmpl b/internal/generator/edwards/template/point.go.tmpl index 9bb9a7244..7c2249d53 100644 --- a/internal/generator/edwards/template/point.go.tmpl +++ b/internal/generator/edwards/template/point.go.tmpl @@ -33,7 +33,7 @@ const ( mUnmask = 0x7f // size in byte of a compressed point (point.Y --> fr.Element) - sizePointCompressed = fr.Limbs * 8 + sizePointCompressed = fr.Bytes ) // Bytes returns the compressed point as a byte array diff --git a/internal/generator/fft/generate.go b/internal/generator/fft/generate.go index 7fc01f599..0d83ff601 100644 --- a/internal/generator/fft/generate.go +++ b/internal/generator/fft/generate.go @@ -8,6 +8,10 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } + conf.Package = "fft" entries := []bavard.Entry{ {File: filepath.Join(baseDir, "doc.go"), Templates: []string{"doc.go.tmpl"}}, diff --git a/internal/generator/fft/template/domain.go.tmpl b/internal/generator/fft/template/domain.go.tmpl index bdabf4a3e..82201dd45 100644 --- a/internal/generator/fft/template/domain.go.tmpl +++ b/internal/generator/fft/template/domain.go.tmpl @@ -88,6 +88,10 @@ func NewDomain(m uint64) *Domain { {{else if eq .Name "bls24-317"}} rootOfUnity.SetString("16532287748948254263922689505213135976137839535221842169193829039521719560631") const maxOrderRoot uint64 = 60 + domain.FrMultiplicativeGen.SetUint64(7) + {{else if eq .Name "secp256k1"}} + rootOfUnity.SetString("78074008874160198520644763525212887401909906723592317393988542598630163514319") + const maxOrderRoot uint64 = 6 domain.FrMultiplicativeGen.SetUint64(7) {{end}} diff --git a/internal/generator/fft/template/imports.go.tmpl b/internal/generator/fft/template/imports.go.tmpl index ab88ba81a..4f6e4ee1e 100644 --- a/internal/generator/fft/template/imports.go.tmpl +++ b/internal/generator/fft/template/imports.go.tmpl @@ -18,6 +18,8 @@ "github.com/consensys/gnark-crypto/ecc/bls24-315/fr" {{ else if eq .Name "bls24-317"}} "github.com/consensys/gnark-crypto/ecc/bls24-317/fr" +{{ else if eq .Name "secp256k1"}} + "github.com/consensys/gnark-crypto/ecc/secp256k1/fr" {{end}} {{end}} @@ -41,6 +43,8 @@ curve "github.com/consensys/gnark-crypto/ecc/bls24-315" {{ else if eq .Name "bls24-317"}} curve "github.com/consensys/gnark-crypto/ecc/bls24-317" +{{ else if eq .Name "secp256k1"}} + curve "github.com/consensys/gnark-crypto/ecc/secp256k1" {{end}} diff --git a/internal/generator/fri/template/generate.go b/internal/generator/fri/template/generate.go index 3ac33c51a..c1fb45427 100644 --- a/internal/generator/fri/template/generate.go +++ b/internal/generator/fri/template/generate.go @@ -8,6 +8,9 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } // fri commitment scheme conf.Package = "fri" diff --git a/internal/generator/gkr/generate.go b/internal/generator/gkr/generate.go new file mode 100644 index 000000000..02fc45568 --- /dev/null +++ b/internal/generator/gkr/generate.go @@ -0,0 +1,29 @@ +package gkr + +import ( + "path/filepath" + + "github.com/consensys/bavard" + "github.com/consensys/gnark-crypto/internal/generator/config" +) + +type Config struct { + config.FieldDependency + GenerateTests bool + RetainTestCaseRawInfo bool + OutsideGkrPackage bool + TestVectorsRelativePath string +} + +func Generate(conf Config, baseDir string, bgen *bavard.BatchGenerator) error { + entries := []bavard.Entry{ + {File: filepath.Join(baseDir, "gkr.go"), Templates: []string{"gkr.go.tmpl"}}, + } + + if conf.GenerateTests { + entries = append(entries, + bavard.Entry{File: filepath.Join(baseDir, "gkr_test.go"), Templates: []string{"gkr.test.go.tmpl", "gkr.test.vectors.go.tmpl"}}) + } + + return bgen.Generate(conf, "gkr", "./gkr/template/", entries...) +} diff --git a/internal/generator/gkr/template/gkr.go.tmpl b/internal/generator/gkr/template/gkr.go.tmpl new file mode 100644 index 000000000..e69a0c647 --- /dev/null +++ b/internal/generator/gkr/template/gkr.go.tmpl @@ -0,0 +1,758 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "{{.FieldPackagePath}}/sumcheck" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" + "sync" +) + +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...{{.ElementType}}) {{.ElementType}} + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]{{.ElementType}} + claimedEvaluations []{{.ElementType}} + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a {{.ElementType}}) {{.ElementType}} { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]{{.ElementType}}) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation {{.ElementType}} + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]{{.ElementType}}, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]{{.ElementType}} // x in the paper + claimedEvaluations []{{.ElementType}} // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff {{.ElementType}}) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]{{.ElementType}}, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]{{.ElementType}} + + if parallel { + gateInput = [][]{{.ElementType}}{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]{{.ElementType}}{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *{{.ElementType}}, gateInput []{{.ElementType}}, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf {{.ElementType}} + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element {{.ElementType}}) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []{{.ElementType}}) interface{} { + + //defer the proof, return list of claims + evaluations := make([]{{.ElementType}}, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]{{.ElementType}}, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []{{.ElementType}}, evaluation {{.ElementType}}) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func (options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]{{.ElementType}}, error) { + res := make([]{{.ElementType}}, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []{{.ElementType}}{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]{{.ElementType}}) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []{{.ElementType}} + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]{{.ElementType}}) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...{{.ElementType}}) {{.ElementType}} { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// {{$topologicalSort}} sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func {{$topologicalSort}}(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := {{$topologicalSort}}(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]{{.ElementType}}, numEvaluations) + ins := make([]{{.ElementType}}, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.go.tmpl b/internal/generator/gkr/template/gkr.test.go.tmpl new file mode 100644 index 000000000..0759a69a0 --- /dev/null +++ b/internal/generator/gkr/template/gkr.test.go.tmpl @@ -0,0 +1,466 @@ + +import ( + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/mimc" + "{{.FieldPackagePath}}/polynomial" + "{{.FieldPackagePath}}/sumcheck" + "{{.FieldPackagePath}}/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "fmt" + "os" + "strconv" + "testing" + "path/filepath" + "encoding/json" + "reflect" +) + +{{$GenerateLargeTests := .GenerateTests}} {{/* this is redundant. soon to be removed if a use case for it doesn't come back */}} +{{$topologicalSort := select (eq .ElementType "fr.Element") "TopologicalSort" "topologicalSort"}} + +func TestNoGateTwoInstances(t *testing.T) { + // Testing a single instance is not possible because the sumcheck implementation doesn't cover the trivial 0-variate case + testNoGate(t, []{{.ElementType}}{four, three}) +} + +func TestNoGate(t *testing.T) { + testManyInstances(t, 1, testNoGate) +} + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []{{.ElementType}}{four, three}, []{{.ElementType}}{two, three}) +} + +func TestSingleMulGate(t *testing.T) { + testManyInstances(t, 2, testSingleMulGate) +} + +func TestSingleInputTwoIdentityGatesTwoInstances(t *testing.T) { + + testSingleInputTwoIdentityGates(t, []{{.ElementType}}{two, three}) +} + +func TestSingleInputTwoIdentityGates(t *testing.T) { + + testManyInstances(t, 2, testSingleInputTwoIdentityGates) +} + +func TestSingleInputTwoIdentityGatesComposedTwoInstances(t *testing.T) { + testSingleInputTwoIdentityGatesComposed(t, []{{.ElementType}}{two, one}) +} + +func TestSingleInputTwoIdentityGatesComposed(t *testing.T) { + testManyInstances(t, 1, testSingleInputTwoIdentityGatesComposed) +} + +func TestSingleMimcCipherGateTwoInstances(t *testing.T) { + testSingleMimcCipherGate(t, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestSingleMimcCipherGate(t *testing.T) { + testManyInstances(t, 2, testSingleMimcCipherGate) +} + +func TestATimesBSquaredTwoInstances(t *testing.T) { + testATimesBSquared(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestShallowMimcTwoInstances(t *testing.T) { + testMimc(t, 2, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +{{- if $GenerateLargeTests}} +func TestMimcTwoInstances(t *testing.T) { + testMimc(t, 93, []{{.ElementType}}{one, one}, []{{.ElementType}}{one, two}) +} + +func TestMimc(t *testing.T) { + testManyInstances(t, 2, generateTestMimc(93)) +} + +func generateTestMimc(numRounds int) func(*testing.T, ...[]{{.ElementType}}) { + return func(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + testMimc(t, numRounds, inputAssignments...) + } +} + +{{- end}} + +func TestSumcheckFromSingleInputTwoIdentityGatesGateTwoInstances(t *testing.T) { + circuit := Circuit{ Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{}, + nbUniqueOutputs: 2, + } } + + wire := &circuit[0] + + assignment := WireAssignment{&circuit[0]: []{{.ElementType}}{two, three}} + pool := polynomial.NewPool(256, 1<<11) + + claimsManagerGen := func() *claimsManager { + manager := newClaimsManager(circuit, assignment, &pool) + manager.add(wire, []{{.ElementType}}{three}, five) + manager.add(wire, []{{.ElementType}}{four}, six) + return &manager + } + + transcriptGen := test_vector_utils.NewMessageCounterGenerator(4, 1) + + proof, err := sumcheck.Prove(claimsManagerGen().getClaim(wire), fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) + err = sumcheck.Verify(claimsManagerGen().getLazyClaim(wire), proof, fiatshamir.WithHash(transcriptGen(), nil)) + assert.NoError(t, err) +} + +var one, two, three, four, five, six {{.ElementType}} + +func init() { + one.SetOne() + two.Double(&one) + three.Add(&two, &one) + four.Double(&two) + five.Add(&three, &two) + six.Double(&three) +} + +var testManyInstancesLogMaxInstances = -1 + +func getLogMaxInstances(t *testing.T) int { + if testManyInstancesLogMaxInstances == -1 { + + s := os.Getenv("GKR_LOG_INSTANCES") + if s == "" { + testManyInstancesLogMaxInstances = 5 + } else { + var err error + testManyInstancesLogMaxInstances, err = strconv.Atoi(s) + if err != nil { + t.Error(err) + } + } + + } + return testManyInstancesLogMaxInstances +} + +func testManyInstances(t *testing.T, numInput int, test func(*testing.T, ...[]{{.ElementType}})) { + fullAssignments := make([][]{{.ElementType}}, numInput) + maxSize := 1 << getLogMaxInstances(t) + + t.Log("Entered test orchestrator, assigning and randomizing inputs") + + for i := range fullAssignments { + fullAssignments[i] = make([]fr.Element, maxSize) + setRandom(fullAssignments[i]) + } + + inputAssignments := make([][]{{.ElementType}}, numInput) + for numEvals := maxSize; numEvals <= maxSize; numEvals *= 2 { + for i, fullAssignment := range fullAssignments { + inputAssignments[i] = fullAssignment[:numEvals] + } + + t.Log("Selected inputs for test") + test(t, inputAssignments...) + } +} + +func testNoGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := Circuit{ + { + Inputs: []*Wire{}, + Gate: nil, + }, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]} + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + // Even though a hash is called here, the proof is empty + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleInputTwoIdentityGates(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func testSingleMimcCipherGate(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[2] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + t.Log("Proof complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification complete") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification complete") +} + +func testSingleInputTwoIdentityGatesComposed(t *testing.T, inputAssignments ...[]{{.ElementType}}) { + c := make(Circuit, 3) + + c[1] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[0]}, + } + c[2] = Wire{ + Gate: IdentityGate{}, + Inputs: []*Wire{&c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func mimcCircuit(numRounds int) Circuit { + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mimcCipherGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + return c +} + +func testMimc(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + //TODO: Implement mimc correctly. Currently, the computation is mimc(a,b) = cipher( cipher( ... cipher(a, b), b) ..., b) + // @AlexandreBelling: Please explain the extra layers in https://github.com/ConsenSys/gkr-mimc/blob/81eada039ab4ed403b7726b535adb63026e8011f/examples/mimc.go#L10 + + c := mimcCircuit(numRounds) + + t.Log("Evaluating all circuit wires") + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + t.Log("Circuit evaluation complete") + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + t.Log("Proof finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + t.Log("Successful verification finished") + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") + t.Log("Unsuccessful verification finished") +} + +func testATimesBSquared(t *testing.T, numRounds int, inputAssignments ...[]{{.ElementType}}) { + // This imitates the MiMC circuit + + c := make(Circuit, numRounds+2) + + for i := 2; i < len(c); i++ { + c[i] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[i-1], &c[0]}, + } + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +func setRandom(slice []{{.ElementType}}) { + for i := range slice { + slice[i].SetRandom() + } +} + +func generateTestProver(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + proof, err := Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + assert.NoError(t, err) + assert.NoError(t, proofEquals(testCase.Proof, proof)) + } +} + +func generateTestVerifier(path string) func(t *testing.T) { + return func(t *testing.T) { + testCase, err := newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, testCase.transcriptSetting()) + assert.NoError(t, err, "proof rejected") + testCase, err = newTestCase(path) + assert.NoError(t, err) + err = Verify(testCase.Circuit, testCase.InOutAssignment, testCase.Proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: testCase.Hash}, []byte{1})) + assert.NotNil(t, err, "bad proof accepted") + } +} + +func TestGkrVectors(t *testing.T) { + + testDirPath := "{{.TestVectorsRelativePath}}" + dirEntries, err := os.ReadDir(testDirPath) + assert.NoError(t, err) + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + path := filepath.Join(testDirPath, dirEntry.Name()) + noExt := dirEntry.Name()[:len(dirEntry.Name())-len(".json")] + + t.Run(noExt+"_prover", generateTestProver(path)) + t.Run(noExt+"_verifier", generateTestVerifier(path)) + + } + } + } +} + +func proofEquals(expected Proof, seen Proof) error { + if len(expected) != len(seen) { + return fmt.Errorf("length mismatch %d ≠ %d", len(expected), len(seen)) + } + for i, x := range expected { + xSeen := seen[i] + + if xSeen.FinalEvalProof == nil { + if seenFinalEval := x.FinalEvalProof.([]fr.Element); len(seenFinalEval) != 0 { + return fmt.Errorf("length mismatch %d ≠ %d", 0, len(seenFinalEval)) + } + } else { + if err := test_vector_utils.SliceEquals(x.FinalEvalProof.([]fr.Element), xSeen.FinalEvalProof.([]fr.Element)); err != nil { + return fmt.Errorf("final evaluation proof mismatch") + } + } + if err := test_vector_utils.PolynomialSliceEquals(x.PartialSumPolys, xSeen.PartialSumPolys); err != nil { + return err + } + } + return nil +} + +func BenchmarkGkrMimc(b *testing.B) { + const N = 1 << 19 + fmt.Println("creating circuit structure") + c := mimcCircuit(91) + + in0 := make([]fr.Element, N) + in1 := make([]fr.Element, N) + setRandom(in0) + setRandom(in1) + + fmt.Println("evaluating circuit") + assignment := WireAssignment{&c[0]: in0, &c[1]: in1}.Complete(c) + + //b.ResetTimer() + fmt.Println("constructing proof") + Prove(c, assignment, fiatshamir.WithHash(mimc.NewMiMC())) +} + +func TestTopSortTrivial(t *testing.T) { + c := make(Circuit, 2) + c[0].Inputs = []*Wire{&c[1]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[1], &c[0]}, sorted) +} + +func TestTopSortDeep(t *testing.T) { + c := make(Circuit, 4) + c[0].Inputs = []*Wire{&c[2]} + c[1].Inputs = []*Wire{&c[3]} + c[2].Inputs = []*Wire{} + c[3].Inputs = []*Wire{&c[0]} + sorted := {{$topologicalSort}}(c) + assert.Equal(t, []*Wire{&c[2], &c[0], &c[3], &c[1]}, sorted) +} + +func TestTopSortWide(t *testing.T) { + c := make(Circuit, 10) + c[0].Inputs = []*Wire{&c[3], &c[8]} + c[1].Inputs = []*Wire{&c[6]} + c[2].Inputs = []*Wire{&c[4]} + c[3].Inputs = []*Wire{} + c[4].Inputs = []*Wire{} + c[5].Inputs = []*Wire{&c[9]} + c[6].Inputs = []*Wire{&c[9]} + c[7].Inputs = []*Wire{&c[9], &c[5], &c[2]} + c[8].Inputs = []*Wire{&c[4], &c[3]} + c[9].Inputs = []*Wire{} + + sorted := {{$topologicalSort}}(c) + sortedExpected := []*Wire{&c[3], &c[4], &c[2], &c[8], &c[0], &c[9], &c[5], &c[6], &c[1], &c[7]} + + assert.Equal(t, sortedExpected, sorted) +} + +{{template "gkrTestVectors" .}} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl b/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl new file mode 100644 index 000000000..4f8fe1524 --- /dev/null +++ b/internal/generator/gkr/template/gkr.test.vectors.gen.go.tmpl @@ -0,0 +1,119 @@ +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "os" + "path/filepath" + "reflect" +) + +func main() { + if err := func() error { + if err := GenerateVectors(); err != nil { + return err + } + return test_vector_utils.SaveUsedHashEntries() + }(); err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } +} + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("gkr/test_vectors") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + fmt.Println("\tprocessing", dirEntry.Name()) + path := filepath.Join(testDirPath, dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, testCase.transcriptSetting()) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, testCase.transcriptSetting([]byte{0, 1})) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +{{template "gkrTestVectors" .}} \ No newline at end of file diff --git a/internal/generator/gkr/template/gkr.test.vectors.go.tmpl b/internal/generator/gkr/template/gkr.test.vectors.go.tmpl new file mode 100644 index 000000000..253b90910 --- /dev/null +++ b/internal/generator/gkr/template/gkr.test.vectors.go.tmpl @@ -0,0 +1,277 @@ +{{define "gkrTestVectors"}} + +{{$GkrPackagePrefix := select .OutsideGkrPackage "" "gkr."}} +{{$CheckOutputCorrectness := not .OutsideGkrPackage}} + +{{$Circuit := print $GkrPackagePrefix "Circuit"}} +{{$Gate := print $GkrPackagePrefix "Gate"}} +{{$Proof := print $GkrPackagePrefix "Proof"}} +{{$WireAssignment := print $GkrPackagePrefix "WireAssignment"}} +{{$Wire := print $GkrPackagePrefix "Wire"}} +{{$CircuitLayer := print $GkrPackagePrefix "CircuitLayer"}} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]{{$Circuit}}) + +func getCircuit(path string) ({{$Circuit}}, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit {{$Circuit}}) { + circuit = make({{$Circuit}}, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*{{$Wire}}, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]{{$Gate}} + +func init() { + gates = make(map[string]{{$Gate}}) + gates["identity"] = {{$GkrPackagePrefix}}IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark {{.ElementType}} +} + +func (m mimcCipherGate) Evaluate(input ...{{.ElementType}}) (res {{.ElementType}}) { + var sum {{.ElementType}} + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) ({{$Proof}}, error) { + proof := make({{$Proof}}, len(printable)) + for i := range printable { + finalEvalProof := []{{.ElementType}}(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]{{.ElementType}}, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := {{ setElement "finalEvalProof[k]" "finalEvalSlice.Index(k).Interface()" .ElementType}}; err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit {{$Circuit}} + Hash *test_vector_utils.ElementMap + Proof {{$Proof}} + FullAssignment {{$WireAssignment}} + InOutAssignment {{$WireAssignment}} + {{if .RetainTestCaseRawInfo}}Info TestCaseInfo{{end}} +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit {{$Circuit}} + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof {{$Proof}} + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make({{$WireAssignment}}) + inOutAssignment := make({{$WireAssignment}}) + + sorted := {{select .OutsideGkrPackage "t" "gkr.T"}}opologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []{{.ElementType}} + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + {{if not $CheckOutputCorrectness}} + info.Output = make([][]interface{}, 0, outI) + {{end}} + + for _, w := range sorted { + if w.IsOutput() { + {{if $CheckOutputCorrectness}} + if err = test_vector_utils.SliceEquals(inOutAssignment[w], fullAssignment[w]); err != nil { + return nil, fmt.Errorf("assignment mismatch: %v", err) + } + {{else}} + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + {{end}} + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + {{if .RetainTestCaseRawInfo }}Info: info,{{end}} + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...{{.ElementType}}) (result {{.ElementType}}) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...{{.ElementType}}) {{.ElementType}} { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} + +{{end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} test_vector_utils.SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} +{{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/main.go b/internal/generator/gkr/test_vectors/main.go new file mode 100644 index 000000000..1515bef8d --- /dev/null +++ b/internal/generator/gkr/test_vectors/main.go @@ -0,0 +1,385 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package main + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/gkr" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "os" + "path/filepath" + "reflect" +) + +func main() { + if err := func() error { + if err := GenerateVectors(); err != nil { + return err + } + return test_vector_utils.SaveUsedHashEntries() + }(); err != nil { + fmt.Println(err.Error()) + os.Exit(-1) + } +} + +func GenerateVectors() error { + testDirPath, err := filepath.Abs("gkr/test_vectors") + if err != nil { + return err + } + + fmt.Printf("generating GKR test cases: scanning directory %s for test specs\n", testDirPath) + + dirEntries, err := os.ReadDir(testDirPath) + if err != nil { + return err + } + for _, dirEntry := range dirEntries { + if !dirEntry.IsDir() { + + if filepath.Ext(dirEntry.Name()) == ".json" { + fmt.Println("\tprocessing", dirEntry.Name()) + path := filepath.Join(testDirPath, dirEntry.Name()) + if err = run(path); err != nil { + return err + } + } + } + } + + return nil +} + +func run(absPath string) error { + testCase, err := newTestCase(absPath) + if err != nil { + return err + } + + var proof gkr.Proof + proof, err = gkr.Prove(testCase.Circuit, testCase.FullAssignment, testCase.transcriptSetting()) + if err != nil { + return err + } + + if testCase.Info.Proof, err = toPrintableProof(proof); err != nil { + return err + } + var outBytes []byte + if outBytes, err = json.MarshalIndent(testCase.Info, "", "\t"); err == nil { + if err = os.WriteFile(absPath, outBytes, 0); err != nil { + return err + } + } else { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, testCase.transcriptSetting()) + if err != nil { + return err + } + + testCase, err = newTestCase(absPath) + if err != nil { + return err + } + + err = gkr.Verify(testCase.Circuit, testCase.InOutAssignment, proof, testCase.transcriptSetting([]byte{0, 1})) + if err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func toPrintableProof(proof gkr.Proof) (PrintableProof, error) { + res := make(PrintableProof, len(proof)) + + for i := range proof { + + partialSumPolys := make([][]interface{}, len(proof[i].PartialSumPolys)) + for k, partialK := range proof[i].PartialSumPolys { + partialSumPolys[k] = test_vector_utils.ElementSliceToInterfaceSlice(partialK) + } + + res[i] = PrintableSumcheckProof{ + FinalEvalProof: test_vector_utils.ElementSliceToInterfaceSlice(proof[i].FinalEvalProof), + PartialSumPolys: partialSumPolys, + } + } + return res, nil +} + +type WireInfo struct { + Gate string `json:"gate"` + Inputs []int `json:"inputs"` +} + +type CircuitInfo []WireInfo + +var circuitCache = make(map[string]gkr.Circuit) + +func getCircuit(path string) (gkr.Circuit, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if circuit, ok := circuitCache[path]; ok { + return circuit, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var circuitInfo CircuitInfo + if err = json.Unmarshal(bytes, &circuitInfo); err == nil { + circuit := circuitInfo.toCircuit() + circuitCache[path] = circuit + return circuit, nil + } else { + return nil, err + } + } else { + return nil, err + } +} + +func (c CircuitInfo) toCircuit() (circuit gkr.Circuit) { + circuit = make(gkr.Circuit, len(c)) + for i := range c { + circuit[i].Gate = gates[c[i].Gate] + circuit[i].Inputs = make([]*gkr.Wire, len(c[i].Inputs)) + for k, inputCoord := range c[i].Inputs { + input := &circuit[inputCoord] + circuit[i].Inputs[k] = input + } + } + return +} + +var gates map[string]gkr.Gate + +func init() { + gates = make(map[string]gkr.Gate) + gates["identity"] = gkr.IdentityGate{} + gates["mul"] = mulGate{} + gates["mimc"] = mimcCipherGate{} //TODO: Add ark + gates["select-input-3"] = _select(2) +} + +type mimcCipherGate struct { + ark small_rational.SmallRational +} + +func (m mimcCipherGate) Evaluate(input ...small_rational.SmallRational) (res small_rational.SmallRational) { + var sum small_rational.SmallRational + + sum. + Add(&input[0], &input[1]). + Add(&sum, &m.ark) + + res.Square(&sum) // sum^2 + res.Mul(&res, &sum) // sum^3 + res.Square(&res) //sum^6 + res.Mul(&res, &sum) //sum^7 + + return +} + +func (m mimcCipherGate) Degree() int { + return 7 +} + +type PrintableProof []PrintableSumcheckProof + +type PrintableSumcheckProof struct { + FinalEvalProof interface{} `json:"finalEvalProof"` + PartialSumPolys [][]interface{} `json:"partialSumPolys"` +} + +func unmarshalProof(printable PrintableProof) (gkr.Proof, error) { + proof := make(gkr.Proof, len(printable)) + for i := range printable { + finalEvalProof := []small_rational.SmallRational(nil) + + if printable[i].FinalEvalProof != nil { + finalEvalSlice := reflect.ValueOf(printable[i].FinalEvalProof) + finalEvalProof = make([]small_rational.SmallRational, finalEvalSlice.Len()) + for k := range finalEvalProof { + if _, err := finalEvalProof[k].SetInterface(finalEvalSlice.Index(k).Interface()); err != nil { + return nil, err + } + } + } + + proof[i] = sumcheck.Proof{ + PartialSumPolys: make([]polynomial.Polynomial, len(printable[i].PartialSumPolys)), + FinalEvalProof: finalEvalProof, + } + for k := range printable[i].PartialSumPolys { + var err error + if proof[i].PartialSumPolys[k], err = test_vector_utils.SliceToElementSlice(printable[i].PartialSumPolys[k]); err != nil { + return nil, err + } + } + } + return proof, nil +} + +type TestCase struct { + Circuit gkr.Circuit + Hash *test_vector_utils.ElementMap + Proof gkr.Proof + FullAssignment gkr.WireAssignment + InOutAssignment gkr.WireAssignment + Info TestCaseInfo +} + +type TestCaseInfo struct { + Hash string `json:"hash"` + Circuit string `json:"circuit"` + Input [][]interface{} `json:"input"` + Output [][]interface{} `json:"output"` + Proof PrintableProof `json:"proof"` +} + +var testCases = make(map[string]*TestCase) + +func newTestCase(path string) (*TestCase, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + dir := filepath.Dir(path) + + tCase, ok := testCases[path] + if !ok { + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var info TestCaseInfo + err = json.Unmarshal(bytes, &info) + if err != nil { + return nil, err + } + + var circuit gkr.Circuit + if circuit, err = getCircuit(filepath.Join(dir, info.Circuit)); err != nil { + return nil, err + } + var _hash *test_vector_utils.ElementMap + if _hash, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, info.Hash)); err != nil { + return nil, err + } + var proof gkr.Proof + if proof, err = unmarshalProof(info.Proof); err != nil { + return nil, err + } + + fullAssignment := make(gkr.WireAssignment) + inOutAssignment := make(gkr.WireAssignment) + + sorted := gkr.TopologicalSort(circuit) + + inI, outI := 0, 0 + for _, w := range sorted { + var assignmentRaw []interface{} + if w.IsInput() { + if inI == len(info.Input) { + return nil, fmt.Errorf("fewer input in vector than in circuit") + } + assignmentRaw = info.Input[inI] + inI++ + } else if w.IsOutput() { + if outI == len(info.Output) { + return nil, fmt.Errorf("fewer output in vector than in circuit") + } + assignmentRaw = info.Output[outI] + outI++ + } + if assignmentRaw != nil { + var wireAssignment []small_rational.SmallRational + if wireAssignment, err = test_vector_utils.SliceToElementSlice(assignmentRaw); err != nil { + return nil, err + } + + fullAssignment[w] = wireAssignment + inOutAssignment[w] = wireAssignment + } + } + + fullAssignment.Complete(circuit) + + info.Output = make([][]interface{}, 0, outI) + + for _, w := range sorted { + if w.IsOutput() { + + info.Output = append(info.Output, test_vector_utils.ElementSliceToInterfaceSlice(inOutAssignment[w])) + + } + } + + tCase = &TestCase{ + FullAssignment: fullAssignment, + InOutAssignment: inOutAssignment, + Proof: proof, + Hash: _hash, + Circuit: circuit, + Info: info, + } + + testCases[path] = tCase + } else { + return nil, err + } + } + + return tCase, nil +} + +func (c *TestCase) transcriptSetting(initialChallenge ...[]byte) fiatshamir.Settings { + return fiatshamir.WithHash(&test_vector_utils.MapHash{Map: c.Hash}, initialChallenge...) +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...small_rational.SmallRational) (result small_rational.SmallRational) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} + +type _select int + +func (g _select) Evaluate(in ...small_rational.SmallRational) small_rational.SmallRational { + return in[g] +} + +func (g _select) Degree() int { + return 1 +} diff --git a/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json b/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json new file mode 100644 index 000000000..94f45f4a9 --- /dev/null +++ b/internal/generator/gkr/test_vectors/mimc_five_levels_two_instances._json @@ -0,0 +1,7 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/mimc_five_levels.json", + "input": [[1, 3], [1, 3], [1, 3], [1, 3], [1, 3], [1, 3]], + "output": [[4, 3]], + "proof": [[{"partialSumPolys":[[3,4]],"finalEvalProof":[3]}],[{"partialSumPolys":null,"finalEvalProof":null}]] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/hash.json b/internal/generator/gkr/test_vectors/resources/hash.json new file mode 100644 index 000000000..9a96e4aaa --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/hash.json @@ -0,0 +1,100 @@ +{ + "-1145901219840000000,5":1, + "-364381509670404096,1":5, + "-95630089550561280,-5":1, + "-19168501451194368,5":-5, + "-2559554567012352,-1":5, + "-2559554567012352,1":5, + "-172523520000000,-1":1, + "-172523520000000,2":-1, + "-2861958168576,-2":-1, + "-2861958168576,0":2, + "-191032962484,-4":4, + "-191032962484,4":-5, + "-72481259520,4":-4, + "-72481259520,5":4, + "-24034694442,0":4, + "-24034694442,3":5, + "-6664736128,-5":0, + "-6664736128,1":3, + "-1440000120,-5":-5, + "-1440000120,2":1, + "-520000000,0":5, + "-408944640,-5":0, + "-408944640,4":-2, + "-215233605,-4":0, + "-215233605,1":0, + "-213909504,-3":-5, + "-213909504,-1":2, + "-79691776,-4":-4, + "-79691776,-1":1, + "-25529833,-1":-4, + "-25529833,4":-1, + "-16796110,-5":-3, + "-16796110,-1":-1, + "-6718464,-4":-1, + "-6718464,3":4, + "-1328125,1":-4, + "-1328125,2":3, + "-292992,-2":-1, + "-292992,1":-5, + "-163840,1":2, + "-163840,2":1, + "-6561,-3":1, + "-6561,1":2, + "-272,0":0, + "-128,-3":5, + "-90,-3":0, + "-85,0":5, + "-85,1":0, + "-80,-3":1, + "-80,-2":0, + "-40,-4":-5, + "-40,-1":3, + "-40,4":-4, + "-33,-2":-2, + "-33,0":-3, + "-30,-2":-5, + "-30,0":-3, + "-27,-3":-2, + "-27,1":-3, + "-20,-4":1, + "-20,-2":3, + "-12,-2":-3, + "-9,-5":-1, + "-9,-4":-2, + "-9,-3":4, + "-9,-2":-4, + "-9,1":-4, + "-6,-3":-2, + "-6,1":0, + "-5,-4":-2, + "-5,-2":-1, + "-5,4":-5, + "-4,-4":-3, + "-4,-2":-4, + "-4,-1":-3, + "-3,-4":1, + "-3,-3":-3, + "-3,-2":-2, + "-3,-1":-5, + "-3,1":2, + "-2,-2":-2, + "-2,5":0, + "0,-3":-2, + "0,-2":-4, + "0,2":5, + "1,-3":-4, + "1,-2":1, + "1,5":-3, + "3,-3":-2, + "4,4":4, + "5,-2":-2, + "1715678768":-3, + "1715678769":-1, + "33548498023443810":-3, + "8588415549364514352":5, + "8588697024341225008":-2, + "8588978499317935664":-4, + "8588978499317935665":4 +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json b/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json new file mode 100644 index 000000000..3dd74f42b --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/mimc_five_levels.json @@ -0,0 +1,36 @@ +[ + [ + { + "gate": "mimc", + "inputs": [[1,0], [5,5]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[2,0], [5,4]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[3,0], [5,3]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[4,0], [5,2]] + } + ], + [ + { + "gate": "mimc", + "inputs": [[5,0], [5,1]] + } + ], + [ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, {"gate": null, "inputs": []} + ] +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_identity_gate.json b/internal/generator/gkr/test_vectors/resources/single_identity_gate.json new file mode 100644 index 000000000..a44066c7b --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/single_identity_gate.json @@ -0,0 +1,10 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json b/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json new file mode 100644 index 000000000..6181784fa --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/single_input_two_identity_gates.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json b/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json new file mode 100644 index 000000000..c577c1cac --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/single_input_two_outs.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 0] + }, + { + "gate": "identity", + "inputs": [0] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json b/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json new file mode 100644 index 000000000..c89e7d52a --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/single_mimc_gate.json @@ -0,0 +1,7 @@ +[ + {"gate": null, "inputs": []}, {"gate": null, "inputs": []}, + { + "gate": "mimc", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/single_mul_gate.json b/internal/generator/gkr/test_vectors/resources/single_mul_gate.json new file mode 100644 index 000000000..0f65a07ed --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/single_mul_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "mul", + "inputs": [0, 1] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json b/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json new file mode 100644 index 000000000..26681c2f8 --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/two_identity_gates_composed_single_input.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": "identity", + "inputs": [0] + }, + { + "gate": "identity", + "inputs": [1] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json b/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json new file mode 100644 index 000000000..cdbdb3b47 --- /dev/null +++ b/internal/generator/gkr/test_vectors/resources/two_inputs_select-input-3_gate.json @@ -0,0 +1,14 @@ +[ + { + "gate": null, + "inputs": [] + }, + { + "gate": null, + "inputs": [] + }, + { + "gate": "select-input-3", + "inputs": [0,0,1] + } +] \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json new file mode 100644 index 000000000..01cf5b7c4 --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_identity_gate_two_instances.json @@ -0,0 +1,33 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_identity_gate.json", + "input": [ + [ + 4, + 3 + ] + ], + "output": [ + [ + 4, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 3 + ], + "partialSumPolys": [ + [ + -9, + -20 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json b/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json new file mode 100644 index 000000000..d926fc2b6 --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_input_two_identity_gates_two_instances.json @@ -0,0 +1,53 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_input_two_identity_gates.json", + "input": [ + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ], + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + -33, + -128 + ] + ] + }, + { + "finalEvalProof": [ + 5 + ], + "partialSumPolys": [ + [ + -9, + -40 + ] + ] + }, + { + "finalEvalProof": [ + -3 + ], + "partialSumPolys": [ + [ + -9, + -40 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json b/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json new file mode 100644 index 000000000..39ad0f0ac --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_input_two_outs_two_instances.json @@ -0,0 +1,54 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_input_two_outs.json", + "input": [ + [ + 1, + 2 + ] + ], + "output": [ + [ + 1, + 4 + ], + [ + 1, + 2 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [ + [ + -6, + -33 + ] + ] + }, + { + "finalEvalProof": [ + 1 + ], + "partialSumPolys": [ + [ + -12, + -90, + -272 + ] + ] + }, + { + "finalEvalProof": [ + -2 + ], + "partialSumPolys": [ + [ + -6, + -30 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json b/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json new file mode 100644 index 000000000..8c50ed09a --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_mimc_gate_four_instances.json @@ -0,0 +1,64 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1, + 2, + 1 + ], + [ + 1, + 2, + 2, + 1 + ] + ], + "output": [ + [ + 128, + 2187, + 16384, + 128 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 7 + ], + "partialSumPolys": [ + [ + -292992, + -16796110, + "-213909504", + "-1440000120", + "-6664736128", + "-24034694442", + "-72481259520", + "-191032962484" + ], + [ + "-408944640", + "-2861958168576", + "-172523520000000", + "-2559554567012352", + "-19168501451194368", + "-95630089550561280", + "-364381509670404096", + "-1145901219840000000" + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json new file mode 100644 index 000000000..488c9aa24 --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_mimc_gate_two_instances.json @@ -0,0 +1,48 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_mimc_gate.json", + "input": [ + [ + 1, + 1 + ], + [ + 1, + 2 + ] + ], + "output": [ + [ + 128, + 2187 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 1, + 6 + ], + "partialSumPolys": [ + [ + -6561, + -163840, + -1328125, + -6718464, + -25529833, + -79691776, + "-215233605", + "-520000000" + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json b/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json new file mode 100644 index 000000000..bc01efd87 --- /dev/null +++ b/internal/generator/gkr/test_vectors/single_mul_gate_two_instances.json @@ -0,0 +1,43 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/single_mul_gate.json", + "input": [ + [ + 4, + 3 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 8, + 9 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 4, + 2 + ], + "partialSumPolys": [ + [ + -27, + -80, + -85 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json b/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json new file mode 100644 index 000000000..4daede589 --- /dev/null +++ b/internal/generator/gkr/test_vectors/two_identity_gates_composed_single_input_two_instances.json @@ -0,0 +1,44 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/two_identity_gates_composed_single_input.json", + "input": [ + [ + 2, + 1 + ] + ], + "output": [ + [ + 2, + 1 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + 6 + ], + "partialSumPolys": [ + [ + 5, + 0 + ] + ] + }, + { + "finalEvalProof": [ + -3 + ], + "partialSumPolys": [ + [ + -3, + 0 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json b/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json new file mode 100644 index 000000000..ed4fc4f18 --- /dev/null +++ b/internal/generator/gkr/test_vectors/two_inputs_select-input-3_gate_two_instances.json @@ -0,0 +1,42 @@ +{ + "hash": "resources/hash.json", + "circuit": "resources/two_inputs_select-input-3_gate.json", + "input": [ + [ + 0, + 1 + ], + [ + 2, + 3 + ] + ], + "output": [ + [ + 2, + 3 + ] + ], + "proof": [ + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [], + "partialSumPolys": [] + }, + { + "finalEvalProof": [ + -5, + -3 + ], + "partialSumPolys": [ + [ + -9, + -40 + ] + ] + } + ] +} \ No newline at end of file diff --git a/internal/generator/kzg/generate.go b/internal/generator/kzg/generate.go index 640074a48..12e39cf18 100644 --- a/internal/generator/kzg/generate.go +++ b/internal/generator/kzg/generate.go @@ -8,6 +8,9 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } // kzg commitment scheme conf.Package = "kzg" diff --git a/internal/generator/kzg/template/kzg.go.tmpl b/internal/generator/kzg/template/kzg.go.tmpl index b387b4df5..216b05008 100644 --- a/internal/generator/kzg/template/kzg.go.tmpl +++ b/internal/generator/kzg/template/kzg.go.tmpl @@ -66,9 +66,6 @@ func NewSRS(size uint64, bAlpha *big.Int) (*SRS, error) { for i := 1; i < len(alphas); i++ { alphas[i].Mul(&alphas[i-1], &alpha) } - for i := 0; i < len(alphas); i++ { - alphas[i].FromMont() - } g1s := {{ .CurvePackage }}.BatchScalarMultiplicationG1(&gen1Aff, alphas) copy(srs.G1[1:], g1s) @@ -107,7 +104,7 @@ func Commit(p []fr.Element, srs *SRS, nbTasks ...int) (Digest, error) { var res {{ .CurvePackage }}.G1Affine - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} if len(nbTasks) > 0 { config.NbTasks = nbTasks[0] } @@ -377,7 +374,7 @@ func BatchVerifyMultiPoints(digests []Digest, proofs []OpeningProof, points []fr for i := 0; i < len(randomNumbers); i++ { quotients[i].Set(&proofs[i].H) } - config := ecc.MultiExpConfig{ScalarsMont: true} + config := ecc.MultiExpConfig{} _, err := foldedQuotients.MultiExp(quotients, randomNumbers, config) if err != nil { return nil @@ -460,7 +457,7 @@ func fold(di []Digest, fai []fr.Element, ci []fr.Element) (Digest, fr.Element, e // fold the digests ∑ᵢ[cᵢ]([fᵢ(α)]G₁) var foldedDigests Digest - _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{ScalarsMont: true}) + _, err := foldedDigests.MultiExp(di, ci, ecc.MultiExpConfig{}) if err != nil { return foldedDigests, foldedEvaluations, err } diff --git a/internal/generator/main.go b/internal/generator/main.go index 631b5cb8c..63cdc4c32 100644 --- a/internal/generator/main.go +++ b/internal/generator/main.go @@ -8,8 +8,8 @@ import ( "sync" "github.com/consensys/bavard" - "github.com/consensys/gnark-crypto/internal/field" - "github.com/consensys/gnark-crypto/internal/field/generator" + "github.com/consensys/gnark-crypto/field/generator" + field "github.com/consensys/gnark-crypto/field/generator/config" "github.com/consensys/gnark-crypto/internal/generator/config" "github.com/consensys/gnark-crypto/internal/generator/crypto/hash/mimc" "github.com/consensys/gnark-crypto/internal/generator/ecc" @@ -17,11 +17,15 @@ import ( "github.com/consensys/gnark-crypto/internal/generator/edwards/eddsa" "github.com/consensys/gnark-crypto/internal/generator/fft" fri "github.com/consensys/gnark-crypto/internal/generator/fri/template" + "github.com/consensys/gnark-crypto/internal/generator/gkr" "github.com/consensys/gnark-crypto/internal/generator/kzg" "github.com/consensys/gnark-crypto/internal/generator/pairing" + "github.com/consensys/gnark-crypto/internal/generator/pedersen" "github.com/consensys/gnark-crypto/internal/generator/permutation" "github.com/consensys/gnark-crypto/internal/generator/plookup" "github.com/consensys/gnark-crypto/internal/generator/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils" "github.com/consensys/gnark-crypto/internal/generator/tower" ) @@ -36,6 +40,7 @@ var bgen = bavard.NewBatchGenerator(copyrightHolder, copyrightYear, "consensys/g //go:generate go run main.go func main() { var wg sync.WaitGroup + for _, conf := range config.Curves { wg.Add(1) // for each curve, generate the needed files @@ -62,12 +67,12 @@ func main() { // generate fft on fr assertNoError(fft.Generate(conf, filepath.Join(curveDir, "fr", "fft"), bgen)) - // generate polynomial on fr - assertNoError(polynomial.Generate(conf, filepath.Join(curveDir, "fr", "polynomial"), bgen)) - // generate kzg on fr assertNoError(kzg.Generate(conf, filepath.Join(curveDir, "fr", "kzg"), bgen)) + // generate pedersen on fr + assertNoError(pedersen.Generate(conf, filepath.Join(curveDir, "fr", "pedersen"), bgen)) + // generate plookup on fr assertNoError(plookup.Generate(conf, filepath.Join(curveDir, "fr", "plookup"), bgen)) @@ -77,6 +82,15 @@ func main() { // generate mimc on fr assertNoError(mimc.Generate(conf, filepath.Join(curveDir, "fr", "mimc"), bgen)) + frInfo := config.FieldDependency{ + FieldPackagePath: "github.com/consensys/gnark-crypto/ecc/" + conf.Name + "/fr", + FieldPackageName: "fr", + ElementType: "fr.Element", + } + + // generate polynomial on fr + assertNoError(polynomial.Generate(frInfo, filepath.Join(curveDir, "fr", "polynomial"), true, bgen)) + // generate eddsa on companion curves assertNoError(fri.Generate(conf, filepath.Join(curveDir, "fr", "fri"), bgen)) @@ -86,12 +100,29 @@ func main() { // generate pairing tests assertNoError(pairing.Generate(conf, curveDir, bgen)) + if conf.Equal(config.SECP256K1) { + return // TODO @yelhousni + } + // generate sumcheck on fr + assertNoError(sumcheck.Generate(frInfo, filepath.Join(curveDir, "fr", "sumcheck"), bgen)) + + // generate gkr on fr + assertNoError(gkr.Generate(gkr.Config{ + FieldDependency: frInfo, + GenerateTests: true, + TestVectorsRelativePath: "../../../../internal/generator/gkr/test_vectors", + }, filepath.Join(curveDir, "fr", "gkr"), bgen)) + + // generate test vector utils on fr + assertNoError(test_vector_utils.Generate(test_vector_utils.Config{ + FieldDependency: frInfo, + RandomizeMissingHashEntries: false, + }, filepath.Join(curveDir, "fr", "test_vector_utils"), bgen)) + }(conf) } - wg.Wait() - for _, conf := range config.TwistedEdwardsCurves { wg.Add(1) @@ -108,8 +139,32 @@ func main() { } + wg.Add(1) + go func() { + defer wg.Done() + assertNoError(test_vector_utils.GenerateRationals(bgen)) + }() wg.Wait() + wg.Add(2) + go func() { + // generate test vectors for sumcheck + cmd := exec.Command("go", "run", "./sumcheck/test_vectors") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + assertNoError(cmd.Run()) + wg.Done() + }() + + go func() { + // generate test vectors for gkr + cmd := exec.Command("go", "run", "./gkr/test_vectors") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + assertNoError(cmd.Run()) + wg.Done() + }() + // format the whole directory cmd := exec.Command("gofmt", "-s", "-w", baseDir) @@ -122,11 +177,12 @@ func main() { cmd.Stderr = os.Stderr assertNoError(cmd.Run()) - //mathfmt doesn't accept directories. TODO: PR? + //mathfmt doesn't accept directories. TODO: PR pending /*cmd = exec.Command("mathfmt", "-w", baseDir) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr assertNoError(cmd.Run())*/ + wg.Wait() } func assertNoError(err error) { diff --git a/internal/generator/pairing/generate.go b/internal/generator/pairing/generate.go index 1a6039b03..ad7b6a7c7 100644 --- a/internal/generator/pairing/generate.go +++ b/internal/generator/pairing/generate.go @@ -9,6 +9,9 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } packageName := strings.ReplaceAll(conf.Name, "-", "") return bgen.Generate(conf, packageName, "./pairing/template", bavard.Entry{ File: filepath.Join(baseDir, "pairing_test.go"), Templates: []string{"tests/pairing.go.tmpl"}, diff --git a/internal/generator/pedersen/generate.go b/internal/generator/pedersen/generate.go new file mode 100644 index 000000000..2e2d0cea0 --- /dev/null +++ b/internal/generator/pedersen/generate.go @@ -0,0 +1,22 @@ +package pedersen + +import ( + "github.com/consensys/bavard" + "github.com/consensys/gnark-crypto/internal/generator/config" + "path/filepath" +) + +func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } + + // pedersen commitment scheme + conf.Package = "pedersen" + entries := []bavard.Entry{ + {File: filepath.Join(baseDir, "pedersen.go"), Templates: []string{"pedersen.go.tmpl"}}, + {File: filepath.Join(baseDir, "pedersen_test.go"), Templates: []string{"pedersen.test.go.tmpl"}}, + } + return bgen.Generate(conf, conf.Package, "./pedersen/template/", entries...) + +} diff --git a/internal/generator/pedersen/template/pedersen.go.tmpl b/internal/generator/pedersen/template/pedersen.go.tmpl new file mode 100644 index 000000000..19f6c6b9c --- /dev/null +++ b/internal/generator/pedersen/template/pedersen.go.tmpl @@ -0,0 +1,95 @@ +import ( + "crypto/rand" + "fmt" + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark-crypto/ecc/{{.Name}}" + "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" + "math/big" +) + +// Key for proof and verification +type Key struct { + g {{.CurvePackage}}.G2Affine // TODO @tabaie: does this really have to be randomized? + gRootSigmaNeg {{.CurvePackage}}.G2Affine //gRootSigmaNeg = g^{-1/σ} + basis []{{.CurvePackage}}.G1Affine + basisExpSigma []{{.CurvePackage}}.G1Affine +} + +func randomOnG2() ({{.CurvePackage}}.G2Affine, error) { // TODO: Add to G2.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return {{.CurvePackage}}.G2Affine{}, err + } + return {{.CurvePackage}}.HashToG2(gBytes, []byte("random on g2")) +} + +func Setup(basis []{{.CurvePackage}}.G1Affine) (Key, error) { + var ( + k Key + err error + ) + + if k.g, err = randomOnG2(); err != nil { + return k, err + } + + var modMinusOne big.Int + modMinusOne.Sub(fr.Modulus(), big.NewInt(1)) + var sigma *big.Int + if sigma, err = rand.Int(rand.Reader, &modMinusOne); err != nil { + return k, err + } + sigma.Add(sigma, big.NewInt(1)) + + var sigmaInvNeg big.Int + sigmaInvNeg.ModInverse(sigma, fr.Modulus()) + sigmaInvNeg.Sub(fr.Modulus(), &sigmaInvNeg) + k.gRootSigmaNeg.ScalarMultiplication(&k.g, &sigmaInvNeg) + + k.basisExpSigma = make([]{{.CurvePackage}}.G1Affine, len(basis)) + for i := range basis { + k.basisExpSigma[i].ScalarMultiplication(&basis[i], sigma) + } + + k.basis = basis + return k, err +} + +func (k *Key) Commit(values []fr.Element) (commitment {{.CurvePackage}}.G1Affine, knowledgeProof {{.CurvePackage}}.G1Affine, err error) { + + if len(values) != len(k.basis) { + err = fmt.Errorf("unexpected number of values") + return + } + + // TODO @gbotrel this will spawn more than one task, see + // https://github.com/ConsenSys/gnark-crypto/issues/269 + config := ecc.MultiExpConfig{ + NbTasks: 1, // TODO Experiment + } + + if _, err = commitment.MultiExp(k.basis, values, config); err != nil { + return + } + + _, err = knowledgeProof.MultiExp(k.basisExpSigma, values, config) + + return +} + +// VerifyKnowledgeProof checks if the proof of knowledge is valid +func (k *Key) VerifyKnowledgeProof(commitment {{.CurvePackage}}.G1Affine, knowledgeProof {{.CurvePackage}}.G1Affine) error { + + if !commitment.IsInSubGroup() || !knowledgeProof.IsInSubGroup() { + return fmt.Errorf("subgroup check failed") + } + + product, err := {{.CurvePackage}}.Pair([]{{.CurvePackage}}.G1Affine{commitment, knowledgeProof}, []{{.CurvePackage}}.G2Affine{k.g, k.gRootSigmaNeg}) + if err != nil { + return err + } + if product.IsOne() { + return nil + } + return fmt.Errorf("proof rejected") +} diff --git a/internal/generator/pedersen/template/pedersen.test.go.tmpl b/internal/generator/pedersen/template/pedersen.test.go.tmpl new file mode 100644 index 000000000..1df98e4f5 --- /dev/null +++ b/internal/generator/pedersen/template/pedersen.test.go.tmpl @@ -0,0 +1,72 @@ +import ( + "github.com/consensys/gnark-crypto/ecc/{{.Name}}" + "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" + "github.com/stretchr/testify/assert" + "math/rand" + "testing" +) + +func interfaceSliceToFrSlice(t *testing.T, values ...interface{}) []fr.Element { + res := make([]fr.Element, len(values)) + for i, v := range values { + _, err := res[i].SetInterface(v) + assert.NoError(t, err) + } + return res +} + +func randomFrSlice(t *testing.T, size int) []interface{} { + res := make([]interface{}, size) + var err error + for i := range res { + var v fr.Element + res[i], err = v.SetRandom() + assert.NoError(t, err) + } + return res +} + +func randomOnG1() ({{.CurvePackage}}.G1Affine, error) { // TODO: Add to G1.go? + gBytes := make([]byte, fr.Bytes) + if _, err := rand.Read(gBytes); err != nil { + return {{.CurvePackage}}.G1Affine{}, err + } + return {{.CurvePackage}}.HashToG1(gBytes, []byte("random on g2")) +} + +func testCommit(t *testing.T, values ...interface{}) { + + basis := make([]{{.CurvePackage}}.G1Affine, len(values)) + for i := range basis { + var err error + basis[i], err = randomOnG1() + assert.NoError(t, err) + } + + var ( + key Key + err error + commitment, pok {{.CurvePackage}}.G1Affine + ) + + key, err = Setup(basis) + assert.NoError(t, err) + commitment, pok, err = key.Commit(interfaceSliceToFrSlice(t, values...)) + assert.NoError(t, err) + assert.NoError(t, key.VerifyKnowledgeProof(commitment, pok)) + + pok.Neg(&pok) + assert.NotNil(t, key.VerifyKnowledgeProof(commitment, pok)) +} + +func TestCommitToOne(t *testing.T) { + testCommit(t, 1) +} + +func TestCommitSingle(t *testing.T) { + testCommit(t, randomFrSlice(t, 1)...) +} + +func TestCommitFiveElements(t *testing.T) { + testCommit(t, randomFrSlice(t, 5)...) +} diff --git a/internal/generator/permutation/generator.go b/internal/generator/permutation/generator.go index 978102f93..7a3a0f4e2 100644 --- a/internal/generator/permutation/generator.go +++ b/internal/generator/permutation/generator.go @@ -8,6 +8,9 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } // permutation data conf.Package = "permutation" diff --git a/internal/generator/plookup/generate.go b/internal/generator/plookup/generate.go index dd8ad8625..fda73972a 100644 --- a/internal/generator/plookup/generate.go +++ b/internal/generator/plookup/generate.go @@ -8,6 +8,9 @@ import ( ) func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { + if conf.Equal(config.SECP256K1) { + return nil + } // kzg commitment scheme conf.Package = "plookup" diff --git a/internal/generator/polynomial/generate.go b/internal/generator/polynomial/generate.go index a9a55993c..57e45b43c 100644 --- a/internal/generator/polynomial/generate.go +++ b/internal/generator/polynomial/generate.go @@ -7,16 +7,21 @@ import ( "github.com/consensys/gnark-crypto/internal/generator/config" ) -func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { +func Generate(conf config.FieldDependency, baseDir string, generateTests bool, bgen *bavard.BatchGenerator) error { - conf.Package = "polynomial" entries := []bavard.Entry{ {File: filepath.Join(baseDir, "doc.go"), Templates: []string{"doc.go.tmpl"}}, {File: filepath.Join(baseDir, "polynomial.go"), Templates: []string{"polynomial.go.tmpl"}}, {File: filepath.Join(baseDir, "multilin.go"), Templates: []string{"multilin.go.tmpl"}}, {File: filepath.Join(baseDir, "pool.go"), Templates: []string{"pool.go.tmpl"}}, - {File: filepath.Join(baseDir, "polynomial_test.go"), Templates: []string{"polynomial.test.go.tmpl"}}, - {File: filepath.Join(baseDir, "multilin_test.go"), Templates: []string{"multilin.test.go.tmpl"}}, } - return bgen.Generate(conf, conf.Package, "./polynomial/template/", entries...) + + if generateTests { + entries = append(entries, + bavard.Entry{File: filepath.Join(baseDir, "polynomial_test.go"), Templates: []string{"polynomial.test.go.tmpl"}}, + bavard.Entry{File: filepath.Join(baseDir, "multilin_test.go"), Templates: []string{"multilin.test.go.tmpl"}}, + ) + } + + return bgen.Generate(conf, "polynomial", "./polynomial/template/", entries...) } diff --git a/internal/generator/polynomial/template/doc.go.tmpl b/internal/generator/polynomial/template/doc.go.tmpl index 9a47cdf74..37c08ac0b 100644 --- a/internal/generator/polynomial/template/doc.go.tmpl +++ b/internal/generator/polynomial/template/doc.go.tmpl @@ -1,2 +1,2 @@ -// Package {{.Package}} provides polynomial methods and commitment schemes. -package {{.Package}} \ No newline at end of file +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial \ No newline at end of file diff --git a/internal/generator/polynomial/template/multilin.go.tmpl b/internal/generator/polynomial/template/multilin.go.tmpl index 4a1880aef..256262677 100644 --- a/internal/generator/polynomial/template/multilin.go.tmpl +++ b/internal/generator/polynomial/template/multilin.go.tmpl @@ -1,5 +1,6 @@ import ( - "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" + "{{.FieldPackagePath}}" + "math/bits" ) @@ -7,10 +8,10 @@ import ( // The variables are X₁ through Xₙ where n = log(len(.)) // .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) // It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial -type MultiLin []fr.Element +type MultiLin []{{.ElementType}} // Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r -func (m *MultiLin) Fold(r fr.Element) { +func (m *MultiLin) Fold(r {{.ElementType}}) { mid := len(*m) / 2 bottom, top := (*m)[:mid], (*m)[mid:] @@ -29,46 +30,63 @@ func (m *MultiLin) Fold(r fr.Element) { *m = (*m)[:mid] } +func (m MultiLin) Sum() {{.ElementType}} { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} // Evaluate extrapolate the value of the multilinear polynomial corresponding to m // on the given coordinates -func (m MultiLin) Evaluate(coordinates []fr.Element) fr.Element { +func (m MultiLin) Evaluate(coordinates []{{.ElementType}}, p *Pool) {{.ElementType}} { // Folding is a mutating operation - bkCopy := m.Clone() + bkCopy := _clone(m, p) // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) for _, r := range coordinates { bkCopy.Fold(r) } - return bkCopy[0] + result := bkCopy[0] + + _dump(bkCopy, p) + return result } -// Clone creates a deep copy of a book-keeping table. +// Clone creates a deep copy of a bookkeeping table. // Both multilinear interpolation and sumcheck require folding an underlying // array, but folding changes the array. To do both one requires a deep copy -// of the book-keeping table. +// of the bookkeeping table. func (m MultiLin) Clone() MultiLin { - tableDeepCopy := Make(len(m)) - copy(tableDeepCopy, m) - return tableDeepCopy + res := make(MultiLin, len(m)) + copy(res, m) + return res } // Add two bookKeepingTables func (m *MultiLin) Add(left, right MultiLin) { size := len(left) // Check that left and right have the same size - if len(right) != size { - panic("Left and right do not have the right size") - } - // Reallocate the table if necessary - if cap(*m) < size { - *m = make([]fr.Element, size) + if len(right) != size || len(*m) != size{ + panic("left, right and destination must have the right size") } - // Resize the destination table - *m = (*m)[:size] - // Add elementwise for i := 0; i < size; i++ { (*m)[i].Add(&left[i], &right[i]) @@ -89,8 +107,8 @@ func (m *MultiLin) Add(left, right MultiLin) { // x // In other words the polynomial evaluated here is the multilinear extrapolation of // one that evaluates to q' == h' for vectors q', h' of binary values -func EvalEq(q, h []fr.Element) fr.Element { - var res, nxt, one, sum fr.Element +func EvalEq(q, h []{{.ElementType}}) {{.ElementType}} { + var res, nxt, one, sum {{.ElementType}} one.SetOne() for i := 0; i < len(q); i++ { nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ @@ -109,14 +127,11 @@ func EvalEq(q, h []fr.Element) fr.Element { } // Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] -func (m *MultiLin) Eq(q []fr.Element) { +func (m *MultiLin) Eq(q []{{.ElementType}}) { n := len(q) - if len(*m) != 1<= 0; i-- { @@ -48,28 +46,28 @@ func (p *Polynomial) Set(p1 Polynomial) { } // AddConstantInPlace adds a constant to the polynomial, modifying p -func (p *Polynomial) AddConstantInPlace(c *fr.Element) { +func (p *Polynomial) AddConstantInPlace(c *{{.ElementType}}) { for i := 0; i < len(*p); i++ { (*p)[i].Add(&(*p)[i], c) } } // SubConstantInPlace subs a constant to the polynomial, modifying p -func (p *Polynomial) SubConstantInPlace(c *fr.Element) { +func (p *Polynomial) SubConstantInPlace(c *{{.ElementType}}) { for i := 0; i < len(*p); i++ { (*p)[i].Sub(&(*p)[i], c) } } // ScaleInPlace multiplies p by v, modifying p -func (p *Polynomial) ScaleInPlace(c *fr.Element) { +func (p *Polynomial) ScaleInPlace(c *{{.ElementType}}) { for i := 0; i < len(*p); i++ { (*p)[i].Mul(&(*p)[i], c) } } // Scale multiplies p0 by v, storing the result in p -func (p *Polynomial) Scale(c *fr.Element, p0 Polynomial) { +func (p *Polynomial) Scale(c *{{.ElementType}}, p0 Polynomial) { if len(*p) != len(p0) { *p = make(Polynomial, len(p0)) } @@ -112,6 +110,18 @@ func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { return p } +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + // Equal checks equality between two polynomials func (p *Polynomial) Equal(p1 Polynomial) bool { if (*p == nil) != (p1 == nil) { @@ -131,16 +141,10 @@ func (p *Polynomial) Equal(p1 Polynomial) bool { return true } -func signedBigInt(v *fr.Element) big.Int { - var i big.Int - v.ToBigIntRegular(&i) - var iDouble big.Int - iDouble.Lsh(&i, 1) - if iDouble.Cmp(fr.Modulus()) > 0 { - i.Sub(fr.Modulus(), &i) - i.Neg(&i) +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() } - return i } func (p Polynomial) Text(base int) string { @@ -153,12 +157,13 @@ func (p Polynomial) Text(base int) string { continue } - i := signedBigInt(&p[d]) + pD := p[d] + pDText := pD.Text(base) initialLen := builder.Len() - if i.Sign() < 1 { - i.Neg(&i) + if pDText[0] == '-' { + pDText = pDText[1:] if first { builder.WriteString("-") } else { @@ -170,13 +175,8 @@ func (p Polynomial) Text(base int) string { first = false - asInt64 := int64(0) - if i.IsInt64() { - asInt64 = i.Int64() - } - - if asInt64 != 1 || d == 0 { - builder.WriteString(i.Text(base)) + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) } if builder.Len()-initialLen > 10 { diff --git a/internal/generator/polynomial/template/polynomial.test.go.tmpl b/internal/generator/polynomial/template/polynomial.test.go.tmpl index 6a96e77ad..727b2f2e1 100644 --- a/internal/generator/polynomial/template/polynomial.test.go.tmpl +++ b/internal/generator/polynomial/template/polynomial.test.go.tmpl @@ -1,8 +1,8 @@ import ( "math/big" "testing" - - "github.com/consensys/gnark-crypto/ecc/{{ .Name }}/fr" + "github.com/stretchr/testify/assert" + "{{.FieldPackagePath}}" ) func TestPolynomialEval(t *testing.T) { @@ -14,11 +14,11 @@ func TestPolynomialEval(t *testing.T) { } // random value - var point fr.Element + var point {{.ElementType}} point.SetRandom() // compute manually f(val) - var expectedEval, one, den fr.Element + var expectedEval, one, den {{.ElementType}} var expo big.Int one.SetOne() expo.SetUint64(20) @@ -45,14 +45,14 @@ func TestPolynomialAddConstantInPlace(t *testing.T) { } // constant to add - var c fr.Element + var c {{.ElementType}} c.SetRandom() // add constant f.AddConstantInPlace(&c) // check - var expectedCoeffs, one fr.Element + var expectedCoeffs, one {{.ElementType}} one.SetOne() expectedCoeffs.Add(&one, &c) for i := 0; i < 20; i++ { @@ -71,14 +71,14 @@ func TestPolynomialSubConstantInPlace(t *testing.T) { } // constant to sub - var c fr.Element + var c {{.ElementType}} c.SetRandom() // sub constant f.SubConstantInPlace(&c) // check - var expectedCoeffs, one fr.Element + var expectedCoeffs, one {{.ElementType}} one.SetOne() expectedCoeffs.Sub(&one, &c) for i := 0; i < 20; i++ { @@ -97,7 +97,7 @@ func TestPolynomialScaleInPlace(t *testing.T) { } // constant to scale by - var c fr.Element + var c {{.ElementType}} c.SetRandom() // scale by constant @@ -129,7 +129,7 @@ func TestPolynomialAdd(t *testing.T) { } // expected result - var one, two fr.Element + var one, two {{.ElementType}} one.SetOne() two.Double(&one) expectedSum := make(Polynomial, 20) @@ -187,4 +187,14 @@ func TestPolynomialAdd(t *testing.T) { if !_f2.Equal(f2Backup) { t.Fatal("side effect, _f2 should not have been modified") } -} \ No newline at end of file +} + +func TestPolynomialText(t *testing.T) { + var one, negTwo {{.ElementType}} + one.SetOne() + negTwo.SetInt64(-2) + + p := Polynomial{one, negTwo, one} + + assert.Equal(t, "X² - 2X + 1", p.Text(10)) +} diff --git a/internal/generator/polynomial/template/pool.go.tmpl b/internal/generator/polynomial/template/pool.go.tmpl index 9a532ab42..6704fe3e2 100644 --- a/internal/generator/polynomial/template/pool.go.tmpl +++ b/internal/generator/polynomial/template/pool.go.tmpl @@ -1,114 +1,211 @@ - - +{{ $sham := eq .ElementType "small_rational.SmallRational"}} import ( +"{{.FieldPackagePath}}" +{{- if not $sham}} + "encoding/json" "fmt" - "github.com/consensys/gnark-crypto/ecc/{{.Name}}/fr" "reflect" + "runtime" + "sort" "sync" "unsafe" +{{- end}} ) +{{ if $sham}} +// Do as little as possible to instantiate the interface +type Pool struct { +} + +func NewPool(...int) (pool Pool) { + return Pool{} +} + +func (p *Pool) Make(n int) []{{.ElementType}} { + return make([]{{.ElementType}}, n) +} + +func (p *Pool) Dump(...[]{{.ElementType}}) { +} + +func (p *Pool) Clone(slice []{{.ElementType}}) []{{.ElementType}} { + res := p.Make(len(slice)) + copy(res, slice) + return res +} +{{ else}} // Memory management for polynomials -// Copied verbatim from gkr repo +// WARNING: This is not thread safe TODO: Make sure that is not a problem +// TODO: There is a lot of "unsafe" memory management here and needs to be vetted thoroughly + +type sizedPool struct { + maxN int + pool sync.Pool + stats poolStats +} -// Sets a maximum for the array size we keep in pool -const maxNForLargePool int = 1 << 24 -const maxNForSmallPool int = 256 +type inUseData struct { + allocatedFor []uintptr + pool *sizedPool +} -// Aliases because it is annoying to use arrays in all the places -type largeArr = [maxNForLargePool]fr.Element -type smallArr = [maxNForSmallPool]fr.Element +type Pool struct { + //lock sync.Mutex + inUse map[*{{.ElementType}}]inUseData + subPools []sizedPool +} -var rC = sync.Map{} +func (p *sizedPool) get(n int) *{{.ElementType}} { + p.stats.maake(n) + return p.pool.Get().(*{{.ElementType}}) +} -var ( - largePool = sync.Pool{ - New: func() interface{} { - var res largeArr - return &res - }, - } - smallPool = sync.Pool{ - New: func() interface{} { - var res smallArr - return &res - }, +func (p *sizedPool) put(ptr *{{.ElementType}}) { + p.stats.dump() + p.pool.Put(ptr) +} + +func NewPool(maxN ...int) (pool Pool) { + + sort.Ints(maxN) + pool = Pool{ + inUse: make(map[*{{.ElementType}}]inUseData), + subPools: make([]sizedPool, len(maxN)), } -) -// ClearPool Clears the pool completely, shields against memory leaks -// Eg: if we forgot to dump a polynomial at some point, this will ensure the value get dumped eventually -// Returns how many polynomials were cleared that way -func ClearPool() int { - res := 0 - rC.Range(func(k, _ interface{}) bool { - switch ptr := k.(type) { - case *largeArr: - largePool.Put(ptr) - case *smallArr: - smallPool.Put(ptr) - default: - panic(fmt.Sprintf("tried to clear %v", reflect.TypeOf(ptr))) + for i := range pool.subPools { + subPool := &pool.subPools[i] + subPool.maxN = maxN[i] + subPool.pool = sync.Pool{ + New: func() interface{} { + subPool.stats.Allocated++ + return getDataPointer(make([]{{.ElementType}}, 0, subPool.maxN)) + }, } - res++ - return true - }) - return res + } + return } -// CountPool Returns the number of elements in the pool without mutating it -func CountPool() int { - res := 0 - rC.Range(func(_, _ interface{}) bool { - res++ - return true - }) - return res +func (p *Pool) findCorrespondingPool(n int) *sizedPool { + poolI := 0 + for poolI < len(p.subPools) && n > p.subPools[poolI].maxN { + poolI++ + } + return &p.subPools[poolI] // out of bounds error here would mean that n is too large +} + +func (p *Pool) Make(n int) []{{.ElementType}} { + pool := p.findCorrespondingPool(n) + ptr := pool.get(n) + p.addInUse(ptr, pool) + return unsafe.Slice(ptr,n) } -// Make tries to find a reusable polynomial or allocates a new one -func Make(n int) []fr.Element { - if n > maxNForLargePool { - panic(fmt.Sprintf("been provided with size of %v but the maximum is %v", n, maxNForLargePool)) +// Dump dumps a set of polynomials into the pool +func (p *Pool) Dump(slices ...[]{{.ElementType}}) { + for _, slice := range slices { + ptr := getDataPointer(slice) + if metadata, ok := p.inUse[ptr]; ok { + delete(p.inUse, ptr) + metadata.pool.put(ptr) + } else { + panic("attempting to dump a slice not created by the pool") + } } +} + +func (p *Pool) addInUse(ptr *{{.ElementType}}, pool *sizedPool) { + pcs := make([]uintptr, 2) + n := runtime.Callers(3, pcs) - if n <= maxNForSmallPool { - ptr := smallPool.Get().(*smallArr) - rC.Store(ptr, struct{}{}) // registers the pointer being used - return (*ptr)[:n] + if prevPcs, ok := p.inUse[ptr]; ok { // TODO: remove if unnecessary for security + panic(fmt.Errorf("re-allocated non-dumped slice, previously allocated at %v", runtime.CallersFrames(prevPcs.allocatedFor))) } + p.inUse[ptr] = inUseData{ + allocatedFor: pcs[:n], + pool: pool, + } +} - ptr := largePool.Get().(*largeArr) - rC.Store(ptr, struct{}{}) // remember we allocated the pointer is being used - return (*ptr)[:n] +func printFrame(frame runtime.Frame) { + fmt.Printf("\t%s line %d, function %s\n", frame.File, frame.Line, frame.Function) } -// Dump dumps a set of polynomials into the pool -// Returns the number of deallocated polys -func Dump(arrs ...[]fr.Element) int { - cnt := 0 - for _, arr := range arrs { - ptr := ptr(arr) - pool := &smallPool - if len(arr) > maxNForSmallPool { - pool = &largePool - } - // If the rC did not register, then - // either the array was allocated somewhere else which can be ignored - // otherwise a double put which MUST be ignored - if _, ok := rC.Load(ptr); ok { - pool.Put(ptr) - // And deregisters the ptr - rC.Delete(ptr) - cnt++ +func (p *Pool) printInUse() { + fmt.Println("slices never dumped allocated at:") + for _, pcs := range p.inUse { + fmt.Println("-------------------------") + + var frame runtime.Frame + frames := runtime.CallersFrames(pcs.allocatedFor) + more := true + for more { + frame, more = frames.Next() + printFrame(frame) } } - return cnt } -func ptr(m []fr.Element) unsafe.Pointer { - if cap(m) != maxNForSmallPool && cap(m) != maxNForLargePool { - panic(fmt.Sprintf("can't cast to large or small array, the put array's is %v it should have capacity %v or %v", cap(m), maxNForLargePool, maxNForSmallPool)) +type poolStats struct { + Used int + Allocated int + ReuseRate float64 + InUse int + GreatestNUsed int + SmallestNUsed int +} + +type poolsStats struct { + SubPools []poolStats + InUse int +} + +func (s *poolStats) maake(n int) { + s.Used++ + s.InUse++ + if n > s.GreatestNUsed { + s.GreatestNUsed = n + } + if s.SmallestNUsed == 0 || s.SmallestNUsed > n { + s.SmallestNUsed = n + } +} + +func (s *poolStats) dump() { + s.InUse-- +} + +func (s *poolStats) finalize() { + s.ReuseRate = float64(s.Used) / float64(s.Allocated) +} + +func getDataPointer(slice []{{.ElementType}}) *{{.ElementType}} { + header := (*reflect.SliceHeader)(unsafe.Pointer(&slice)) + return (*{{.ElementType}})(unsafe.Pointer(header.Data)) +} + +func (p *Pool) PrintPoolStats() { + InUse := 0 + subStats := make([]poolStats, len(p.subPools)) + for i := range p.subPools { + subPool := &p.subPools[i] + subPool.stats.finalize() + subStats[i] = subPool.stats + InUse += subPool.stats.InUse + } + + poolsStats := poolsStats{ + SubPools: subStats, + InUse: InUse, } - return unsafe.Pointer(&m[0]) + serialized, _ := json.MarshalIndent(poolsStats, "", " ") + fmt.Println(string(serialized)) + p.printInUse() +} + +func (p *Pool) Clone(slice []{{.ElementType}}) []{{.ElementType}} { + res := p.Make(len(slice)) + copy(res, slice) + return res } +{{end}} \ No newline at end of file diff --git a/internal/generator/sumcheck/generate.go b/internal/generator/sumcheck/generate.go new file mode 100644 index 000000000..868fb2313 --- /dev/null +++ b/internal/generator/sumcheck/generate.go @@ -0,0 +1,15 @@ +package sumcheck + +import ( + "github.com/consensys/bavard" + "github.com/consensys/gnark-crypto/internal/generator/config" + "path/filepath" +) + +func Generate(conf config.FieldDependency, baseDir string, bgen *bavard.BatchGenerator) error { + entries := []bavard.Entry{ + {File: filepath.Join(baseDir, "sumcheck.go"), Templates: []string{"sumcheck.go.tmpl"}}, + {File: filepath.Join(baseDir, "sumcheck_test.go"), Templates: []string{"sumcheck.test.go.tmpl"}}, + } + return bgen.Generate(conf, "sumcheck", "./sumcheck/template/", entries...) +} diff --git a/internal/generator/sumcheck/template/sumcheck.go.tmpl b/internal/generator/sumcheck/template/sumcheck.go.tmpl new file mode 100644 index 000000000..80881bdcc --- /dev/null +++ b/internal/generator/sumcheck/template/sumcheck.go.tmpl @@ -0,0 +1,163 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a {{.ElementType}}) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next({{.ElementType}}) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []{{.ElementType}}) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a {{.ElementType}}) {{.ElementType}} // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []{{.ElementType}}, remainingChallengeNames *[]string) ({{.ElementType}}, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return {{.ElementType}}{}, err + } + } + var res {{.ElementType}} + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff {{.ElementType}} + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]{{.ElementType}}, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff {{.ElementType}} + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []{{.ElementType}}{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]{{.ElementType}}, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/generator/sumcheck/template/sumcheck.test.go.tmpl b/internal/generator/sumcheck/template/sumcheck.test.go.tmpl new file mode 100644 index 000000000..b50c31092 --- /dev/null +++ b/internal/generator/sumcheck/template/sumcheck.test.go.tmpl @@ -0,0 +1,143 @@ +import ( + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "{{.FieldPackagePath}}/test_vector_utils" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []{{.ElementType}}) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []{{.ElementType}}{sum} +} + +func (c singleMultilinClaim) Combine({{.ElementType}}) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r {{.ElementType}}) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum {{.ElementType}} +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []{{.ElementType}}, combinationCoeff {{.ElementType}}, purportedValue {{.ElementType}}, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs {{.ElementType}}) {{.ElementType}} { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/generator/sumcheck/test_vectors/hash.json b/internal/generator/sumcheck/test_vectors/hash.json new file mode 100644 index 000000000..a3f147a26 --- /dev/null +++ b/internal/generator/sumcheck/test_vectors/hash.json @@ -0,0 +1,13 @@ +{ + "-14,3":-4, + "-9,1":-4, + "-4,-2":3, + "-2,-4":1, + "3,-3":-2, + "4,-3":-3, + "26,-3":-2, + "27,-3":-2, + "482434100784":-3, + "482434100785":-4, + "482434100786":-2 +} \ No newline at end of file diff --git a/internal/generator/sumcheck/test_vectors/main.go b/internal/generator/sumcheck/test_vectors/main.go new file mode 100644 index 000000000..591042039 --- /dev/null +++ b/internal/generator/sumcheck/test_vectors/main.go @@ -0,0 +1,204 @@ +package main + +import ( + "encoding/json" + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "math/bits" + "os" + "path/filepath" +) + +func runMultilin(dir string, testCaseInfo *TestCaseInfo) error { + + var poly polynomial.MultiLin + if v, err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); err == nil { + poly = v + } else { + return err + } + + var mp *test_vector_utils.ElementMap + var err error + if mp, err = test_vector_utils.ElementMapFromFile(filepath.Join(dir, testCaseInfo.Hash)); err != nil { + return err + } + + proof, err := sumcheck.Prove( + &singleMultilinClaim{poly}, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: mp})) + if err != nil { + return err + } + testCaseInfo.Proof = toPrintableProof(proof) + + // Verification + if v, _err := test_vector_utils.SliceToElementSlice(testCaseInfo.Values); _err == nil { + poly = v + } else { + return _err + } + var claimedSum small_rational.SmallRational + if _, err = claimedSum.SetInterface(testCaseInfo.ClaimedSum); err != nil { + return err + } + + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: mp})); err != nil { + return fmt.Errorf("proof rejected: %v", err) + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + if err = sumcheck.Verify(singleMultilinLazyClaim{g: poly, claimedSum: claimedSum}, proof, fiatshamir.WithHash(&test_vector_utils.MapHash{Map: mp})); err == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func run(dir string, testCaseInfo *TestCaseInfo) error { + switch testCaseInfo.Type { + case "multilin": + return runMultilin(dir, testCaseInfo) + default: + return fmt.Errorf("type \"%s\" unrecognized", testCaseInfo.Type) + } +} + +func runAll(relPath string) error { + var filename string + var err error + if filename, err = filepath.Abs(relPath); err != nil { + return err + } + + dir := filepath.Dir(filename) + + var bytes []byte + + if bytes, err = os.ReadFile(filename); err != nil { + return err + } + + var testCasesInfo TestCasesInfo + if err = json.Unmarshal(bytes, &testCasesInfo); err != nil { + return err + } + + failed := false + for name, testCase := range testCasesInfo { + if err = run(dir, testCase); err != nil { + fmt.Println(name, ":", err) + failed = true + } + } + + if failed { + return fmt.Errorf("test case failed") + } + + if bytes, err = json.MarshalIndent(testCasesInfo, "", "\t"); err != nil { + return err + } + + if err = test_vector_utils.SaveUsedHashEntries(); err != nil { + return err + } + + return os.WriteFile(filename, bytes, 0) +} + +func main() { + if err := runAll("sumcheck/test_vectors/vectors.json"); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +type TestCasesInfo map[string]*TestCaseInfo + +type TestCaseInfo struct { + Type string `json:"type"` + Hash string `json:"hash"` + Values []interface{} `json:"values"` + Description string `json:"description"` + Proof PrintableProof `json:"proof"` + ClaimedSum interface{} `json:"claimedSum"` +} + +type PrintableProof struct { + PartialSumPolys [][]interface{} `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` +} + +func toPrintableProof(proof sumcheck.Proof) (printable PrintableProof) { + if proof.FinalEvalProof != nil { + panic("null expected") + } + printable.FinalEvalProof = struct{}{} + printable.PartialSumPolys = test_vector_utils.ElementSliceSliceToInterfaceSliceSlice(proof.PartialSumPolys) + return +} + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval([]small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, _ small_rational.SmallRational, purportedValue small_rational.SmallRational, _ interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} diff --git a/internal/generator/sumcheck/test_vectors/vectors.json b/internal/generator/sumcheck/test_vectors/vectors.json new file mode 100644 index 000000000..8b1820e17 --- /dev/null +++ b/internal/generator/sumcheck/test_vectors/vectors.json @@ -0,0 +1,50 @@ +{ + "linear_univariate_single_claim": { + "type": "multilin", + "hash": "hash.json", + "values": [ + 1, + 3 + ], + "description": "X ↦ 2X + 1", + "proof": { + "partialSumPolys": [ + [ + 3 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 4 + }, + "trilinear_single_claim": { + "type": "multilin", + "hash": "hash.json", + "values": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8 + ], + "description": "X₁, X₂, X₃ ↦ 1 + 4X₁ + 2X₂ + X₃", + "proof": { + "partialSumPolys": [ + [ + 26 + ], + [ + -9 + ], + [ + -14 + ] + ], + "finalEvalProof": {} + }, + "claimedSum": 36 + } +} \ No newline at end of file diff --git a/internal/generator/test_vector_utils/generate.go b/internal/generator/test_vector_utils/generate.go new file mode 100644 index 000000000..60a8fc9a9 --- /dev/null +++ b/internal/generator/test_vector_utils/generate.go @@ -0,0 +1,57 @@ +package test_vector_utils + +import ( + "github.com/consensys/bavard" + "github.com/consensys/gnark-crypto/internal/generator/config" + "github.com/consensys/gnark-crypto/internal/generator/gkr" + "github.com/consensys/gnark-crypto/internal/generator/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/sumcheck" + "path/filepath" +) + +type Config struct { + config.FieldDependency + RandomizeMissingHashEntries bool +} + +func GenerateRationals(bgen *bavard.BatchGenerator) error { + gkrConf := gkr.Config{ + FieldDependency: config.FieldDependency{ + FieldPackagePath: "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational", + FieldPackageName: "small_rational", + ElementType: "small_rational.SmallRational", + }, + GenerateTests: false, + RetainTestCaseRawInfo: true, + TestVectorsRelativePath: "../../../gkr/test_vectors", + } + + baseDir := "./test_vector_utils/small_rational/" + if err := polynomial.Generate(gkrConf.FieldDependency, baseDir+"polynomial", false, bgen); err != nil { + return err + } + if err := sumcheck.Generate(gkrConf.FieldDependency, baseDir+"sumcheck", bgen); err != nil { + return err + } + if err := gkr.Generate(gkrConf, baseDir+"gkr", bgen); err != nil { + return err + } + if err := Generate(Config{gkrConf.FieldDependency, true}, baseDir+"test_vector_utils", bgen); err != nil { + return err + } + + // generate gkr test vector generator for rationals + gkrConf.OutsideGkrPackage = true + return bgen.Generate(gkrConf, "main", "./gkr/template", bavard.Entry{ + File: filepath.Join("gkr", "test_vectors", "main.go"), Templates: []string{"gkr.test.vectors.gen.go.tmpl", "gkr.test.vectors.go.tmpl"}, + }) + +} + +func Generate(conf Config, baseDir string, bgen *bavard.BatchGenerator) error { + entry := bavard.Entry{ + File: filepath.Join(baseDir, "test_vector_utils.go"), Templates: []string{"test_vector_utils.go.tmpl"}, + } + + return bgen.Generate(conf, "test_vector_utils", "./test_vector_utils/template/", entry) +} diff --git a/internal/generator/test_vector_utils/small_rational/gkr/gkr.go b/internal/generator/test_vector_utils/small_rational/gkr/gkr.go new file mode 100644 index 000000000..5d981324a --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/gkr/gkr.go @@ -0,0 +1,774 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package gkr + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/sumcheck" + "strconv" + "sync" +) + +// The goal is to prove/verify evaluations of many instances of the same circuit + +// Gate must be a low-degree polynomial +type Gate interface { + Evaluate(...small_rational.SmallRational) small_rational.SmallRational + Degree() int +} + +type Wire struct { + Gate Gate + Inputs []*Wire // if there are no Inputs, the wire is assumed an input wire + nbUniqueOutputs int // number of other wires using it as input, not counting duplicates (i.e. providing two inputs to the same gate counts as one) +} + +type Circuit []Wire + +func (w Wire) IsInput() bool { + return len(w.Inputs) == 0 +} + +func (w Wire) IsOutput() bool { + return w.nbUniqueOutputs == 0 +} + +func (w Wire) NbClaims() int { + if w.IsOutput() { + return 1 + } + return w.nbUniqueOutputs +} + +func (w Wire) noProof() bool { + return w.IsInput() && w.NbClaims() == 1 +} + +// WireAssignment is assignment of values to the same wire across many instances of the circuit +type WireAssignment map[*Wire]polynomial.MultiLin + +type Proof []sumcheck.Proof // for each layer, for each wire, a sumcheck (for each variable, a polynomial) + +type eqTimesGateEvalSumcheckLazyClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational + claimedEvaluations []small_rational.SmallRational + manager *claimsManager // WARNING: Circular references +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) ClaimsNum() int { + return len(e.evaluationPoints) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VarsNum() int { + return len(e.evaluationPoints[0]) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) CombinedSum(a small_rational.SmallRational) small_rational.SmallRational { + evalsAsPoly := polynomial.Polynomial(e.claimedEvaluations) + return evalsAsPoly.Eval(&a) +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) Degree(int) int { + return 1 + e.wire.Gate.Degree() +} + +func (e *eqTimesGateEvalSumcheckLazyClaims) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + inputEvaluationsNoRedundancy := proof.([]small_rational.SmallRational) + + // the eq terms + numClaims := len(e.evaluationPoints) + evaluation := polynomial.EvalEq(e.evaluationPoints[numClaims-1], r) + for i := numClaims - 2; i >= 0; i-- { + evaluation.Mul(&evaluation, &combinationCoeff) + eq := polynomial.EvalEq(e.evaluationPoints[i], r) + evaluation.Add(&evaluation, &eq) + } + + // the g(...) term + var gateEvaluation small_rational.SmallRational + if e.wire.IsInput() { + gateEvaluation = e.manager.assignment[e.wire].Evaluate(r, e.manager.memPool) + } else { + inputEvaluations := make([]small_rational.SmallRational, len(e.wire.Inputs)) + indexesInProof := make(map[*Wire]int, len(inputEvaluationsNoRedundancy)) + + proofI := 0 + for inI, in := range e.wire.Inputs { + indexInProof, found := indexesInProof[in] + if !found { + indexInProof = proofI + indexesInProof[in] = indexInProof + + // defer verification, store new claim + e.manager.add(in, r, inputEvaluationsNoRedundancy[indexInProof]) + proofI++ + } + inputEvaluations[inI] = inputEvaluationsNoRedundancy[indexInProof] + } + if proofI != len(inputEvaluationsNoRedundancy) { + return fmt.Errorf("%d input wire evaluations given, %d expected", len(inputEvaluationsNoRedundancy), proofI) + } + gateEvaluation = e.wire.Gate.Evaluate(inputEvaluations...) + } + + evaluation.Mul(&evaluation, &gateEvaluation) + + if evaluation.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("incompatible evaluations") +} + +type eqTimesGateEvalSumcheckClaims struct { + wire *Wire + evaluationPoints [][]small_rational.SmallRational // x in the paper + claimedEvaluations []small_rational.SmallRational // y in the paper + manager *claimsManager + + inputPreprocessors []polynomial.MultiLin // P_u in the paper, so that we don't need to pass along all the circuit's evaluations + + eq polynomial.MultiLin // ∑_i τ_i eq(x_i, -) +} + +func (c *eqTimesGateEvalSumcheckClaims) Combine(combinationCoeff small_rational.SmallRational) polynomial.Polynomial { + varsNum := c.VarsNum() + eqLength := 1 << varsNum + claimsNum := c.ClaimsNum() + // initialize the eq tables + c.eq = c.manager.memPool.Make(eqLength) + + c.eq[0].SetOne() + c.eq.Eq(c.evaluationPoints[0]) + + newEq := polynomial.MultiLin(c.manager.memPool.Make(eqLength)) + aI := combinationCoeff + + for k := 1; k < claimsNum; k++ { //TODO: parallelizable? + // define eq_k = aᵏ eq(x_k1, ..., x_kn, *, ..., *) where x_ki are the evaluation points + newEq[0].Set(&aI) + newEq.Eq(c.evaluationPoints[k]) + + eqAsPoly := polynomial.Polynomial(c.eq) //just semantics + eqAsPoly.Add(eqAsPoly, polynomial.Polynomial(newEq)) + + if k+1 < claimsNum { + aI.Mul(&aI, &combinationCoeff) + } + } + + c.manager.memPool.Dump(newEq) + + // from this point on the claim is a rather simple one: g = E(h) × R_v (P_u0(h), ...) where E and the P_u are multilinear and R_v is of low-degree + + return c.computeGJ() +} + +// computeValAndStep returns val : i ↦ m(1, i...) and step : i ↦ m(1, i...) - m(0, i...) +func computeValAndStep(m polynomial.MultiLin, p *polynomial.Pool) (val polynomial.MultiLin, step polynomial.MultiLin) { + val = p.Clone(m[len(m)/2:]) + step = p.Clone(m[:len(m)/2]) + + valAsPoly, stepAsPoly := polynomial.Polynomial(val), polynomial.Polynomial(step) + + stepAsPoly.Sub(valAsPoly, stepAsPoly) + return +} + +// computeGJ: gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) = ∑_{0≤i<2ⁿ⁻ʲ} E(r₁, ..., X_j, i...) R_v( P_u0(r₁, ..., X_j, i...), ... ) where E = ∑ eq_k +// the polynomial is represented by the evaluations g_j(1), g_j(2), ..., g_j(deg(g_j)). +// The value g_j(0) is inferred from the equation g_j(0) + g_j(1) = g_{j-1}(r_{j-1}). By convention, g_0 is a constant polynomial equal to the claimed sum. +func (c *eqTimesGateEvalSumcheckClaims) computeGJ() (gJ polynomial.Polynomial) { + + // Let f ∈ { E(r₁, ..., X_j, d...) } ∪ {P_ul(r₁, ..., X_j, d...) }. It is linear in X_j, so f(m) = m×(f(1) - f(0)) + f(0), and f(0), f(1) are easily computed from the bookkeeping tables + EVal, EStep := computeValAndStep(c.eq, c.manager.memPool) + + puVal := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO: Make a two-dimensional array struct, and index it i-first rather than inputI first: would result in scanning memory access in the "d" loop and obviate the gateInput variable + puStep := make([]polynomial.MultiLin, len(c.inputPreprocessors)) //TODO, ctd: the greater degGJ, the more this would matter + + for i, puI := range c.inputPreprocessors { + puVal[i], puStep[i] = computeValAndStep(puI, c.manager.memPool) + } + + degGJ := 1 + c.wire.Gate.Degree() // guaranteed to be no smaller than the actual deg(g_j) + gJ = make([]small_rational.SmallRational, degGJ) + + parallel := len(EVal) >= 1024 //TODO: Experiment with threshold + + var gateInput [][]small_rational.SmallRational + + if parallel { + gateInput = [][]small_rational.SmallRational{c.manager.memPool.Make(len(c.inputPreprocessors)), + c.manager.memPool.Make(len(c.inputPreprocessors))} + } else { + gateInput = [][]small_rational.SmallRational{c.manager.memPool.Make(len(c.inputPreprocessors))} + } + + var wg sync.WaitGroup + + for d := 0; d < degGJ; d++ { + + notLastIteration := d+1 < degGJ + + sumOverI := func(res *small_rational.SmallRational, gateInput []small_rational.SmallRational, start, end int) { + for i := start; i < end; i++ { + + for inputI := range puVal { + gateInput[inputI].Set(&puVal[inputI][i]) + if notLastIteration { + puVal[inputI][i].Add(&puVal[inputI][i], &puStep[inputI][i]) + } + } + + // gJAtDI = gJ(d, i...) + gJAtDI := c.wire.Gate.Evaluate(gateInput...) + gJAtDI.Mul(&gJAtDI, &EVal[i]) + + res.Add(res, &gJAtDI) + + if notLastIteration { + EVal[i].Add(&EVal[i], &EStep[i]) + } + } + wg.Done() + } + + if parallel { + var firstHalf, secondHalf small_rational.SmallRational + wg.Add(2) + go sumOverI(&secondHalf, gateInput[1], len(EVal)/2, len(EVal)) + go sumOverI(&firstHalf, gateInput[0], 0, len(EVal)/2) + wg.Wait() + gJ[d].Add(&firstHalf, &secondHalf) + } else { + wg.Add(1) // formalities + sumOverI(&gJ[d], gateInput[0], 0, len(EVal)) + } + } + + c.manager.memPool.Dump(gateInput...) + c.manager.memPool.Dump(EVal, EStep) + + for inputI := range puVal { + c.manager.memPool.Dump(puVal[inputI], puStep[inputI]) + } + + return +} + +// Next first folds the "preprocessing" and "eq" polynomials then compute the new g_j +func (c *eqTimesGateEvalSumcheckClaims) Next(element small_rational.SmallRational) polynomial.Polynomial { + c.eq.Fold(element) + for i := 0; i < len(c.inputPreprocessors); i++ { + c.inputPreprocessors[i].Fold(element) + } + return c.computeGJ() +} + +func (c *eqTimesGateEvalSumcheckClaims) VarsNum() int { + return len(c.evaluationPoints[0]) +} + +func (c *eqTimesGateEvalSumcheckClaims) ClaimsNum() int { + return len(c.claimedEvaluations) +} + +func (c *eqTimesGateEvalSumcheckClaims) ProveFinalEval(r []small_rational.SmallRational) interface{} { + + //defer the proof, return list of claims + evaluations := make([]small_rational.SmallRational, 0, len(c.wire.Inputs)) + noMoreClaimsAllowed := make(map[*Wire]struct{}, len(c.inputPreprocessors)) + noMoreClaimsAllowed[c.wire] = struct{}{} + + for inI, in := range c.wire.Inputs { + puI := c.inputPreprocessors[inI] + if _, found := noMoreClaimsAllowed[in]; !found { + noMoreClaimsAllowed[in] = struct{}{} + puI.Fold(r[len(r)-1]) + c.manager.add(in, r, puI[0]) + evaluations = append(evaluations, puI[0]) + } + c.manager.memPool.Dump(puI) + } + + c.manager.memPool.Dump(c.claimedEvaluations, c.eq) + + return evaluations +} + +type claimsManager struct { + claimsMap map[*Wire]*eqTimesGateEvalSumcheckLazyClaims + assignment WireAssignment + memPool *polynomial.Pool +} + +func newClaimsManager(c Circuit, assignment WireAssignment, pool *polynomial.Pool) (claims claimsManager) { + claims.assignment = assignment + claims.claimsMap = make(map[*Wire]*eqTimesGateEvalSumcheckLazyClaims, len(c)) + claims.memPool = pool + + for i := range c { + wire := &c[i] + + claims.claimsMap[wire] = &eqTimesGateEvalSumcheckLazyClaims{ + wire: wire, + evaluationPoints: make([][]small_rational.SmallRational, 0, wire.NbClaims()), + claimedEvaluations: claims.memPool.Make(wire.NbClaims()), + manager: &claims, + } + } + return +} + +func (m *claimsManager) add(wire *Wire, evaluationPoint []small_rational.SmallRational, evaluation small_rational.SmallRational) { + claim := m.claimsMap[wire] + i := len(claim.evaluationPoints) + claim.claimedEvaluations[i] = evaluation + claim.evaluationPoints = append(claim.evaluationPoints, evaluationPoint) +} + +func (m *claimsManager) getLazyClaim(wire *Wire) *eqTimesGateEvalSumcheckLazyClaims { + return m.claimsMap[wire] +} + +func (m *claimsManager) getClaim(wire *Wire) *eqTimesGateEvalSumcheckClaims { + lazy := m.claimsMap[wire] + res := &eqTimesGateEvalSumcheckClaims{ + wire: wire, + evaluationPoints: lazy.evaluationPoints, + claimedEvaluations: lazy.claimedEvaluations, + manager: m, + } + + if wire.IsInput() { + res.inputPreprocessors = []polynomial.MultiLin{m.memPool.Clone(m.assignment[wire])} + } else { + res.inputPreprocessors = make([]polynomial.MultiLin, len(wire.Inputs)) + + for inputI, inputW := range wire.Inputs { + res.inputPreprocessors[inputI] = m.memPool.Clone(m.assignment[inputW]) //will be edited later, so must be deep copied + } + } + return res +} + +func (m *claimsManager) deleteClaim(wire *Wire) { + delete(m.claimsMap, wire) +} + +type settings struct { + pool *polynomial.Pool + sorted []*Wire + transcript *fiatshamir.Transcript + transcriptPrefix string + nbVars int +} + +type Option func(*settings) + +func WithPool(pool *polynomial.Pool) Option { + return func(options *settings) { + options.pool = pool + } +} + +func WithSortedCircuit(sorted []*Wire) Option { + return func(options *settings) { + options.sorted = sorted + } +} + +func setup(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (settings, error) { + var o settings + var err error + for _, option := range options { + option(&o) + } + + o.nbVars = assignment.NumVars() + nbInstances := assignment.NumInstances() + if 1< b { + return a + } + return b +} + +func ChallengeNames(sorted []*Wire, logNbInstances int, prefix string) []string { + + // Pre-compute the size TODO: Consider not doing this and just grow the list by appending + size := logNbInstances // first challenge + + for _, w := range sorted { + if w.noProof() { // no proof, no challenge + continue + } + if w.NbClaims() > 1 { //combine the claims + size++ + } + size += logNbInstances // full run of sumcheck on logNbInstances variables + } + + nums := make([]string, max(len(sorted), logNbInstances)) + for i := range nums { + nums[i] = strconv.Itoa(i) + } + + challenges := make([]string, size) + + // output wire claims + firstChallengePrefix := prefix + "fC." + for j := 0; j < logNbInstances; j++ { + challenges[j] = firstChallengePrefix + nums[j] + } + j := logNbInstances + for i := len(sorted) - 1; i >= 0; i-- { + if sorted[i].noProof() { + continue + } + wirePrefix := prefix + "w" + nums[i] + "." + + if sorted[i].NbClaims() > 1 { + challenges[j] = wirePrefix + "comb" + j++ + } + + partialSumPrefix := wirePrefix + "pSP." + for k := 0; k < logNbInstances; k++ { + challenges[j] = partialSumPrefix + nums[k] + j++ + } + } + return challenges +} + +func getFirstChallengeNames(logNbInstances int, prefix string) []string { + res := make([]string, logNbInstances) + firstChallengePrefix := prefix + "fC." + for i := 0; i < logNbInstances; i++ { + res[i] = firstChallengePrefix + strconv.Itoa(i) + } + return res +} + +func getChallenges(transcript *fiatshamir.Transcript, names []string) ([]small_rational.SmallRational, error) { + res := make([]small_rational.SmallRational, len(names)) + for i, name := range names { + if bytes, err := transcript.ComputeChallenge(name); err == nil { + res[i].SetBytes(bytes) + } else { + return nil, err + } + } + return res, nil +} + +// Prove consistency of the claimed assignment +func Prove(c Circuit, assignment WireAssignment, transcriptSettings fiatshamir.Settings, options ...Option) (Proof, error) { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return nil, err + } + + claims := newClaimsManager(c, assignment, o.pool) + + proof := make(Proof, len(c)) + // firstChallenge called rho in the paper + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return nil, err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + claim := claims.getClaim(wire) + if wire.noProof() { // input wires with one claim only + proof[i] = sumcheck.Proof{ + PartialSumPolys: []polynomial.Polynomial{}, + FinalEvalProof: []small_rational.SmallRational{}, + } + } else { + if proof[i], err = sumcheck.Prove( + claim, fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err != nil { + return proof, err + } + + finalEvalProof := proof[i].FinalEvalProof.([]small_rational.SmallRational) + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } + // the verifier checks a single claim about input wires itself + claims.deleteClaim(wire) + } + + return proof, nil +} + +// Verify the consistency of the claimed output with the claimed input +// Unlike in Prove, the assignment argument need not be complete +func Verify(c Circuit, assignment WireAssignment, proof Proof, transcriptSettings fiatshamir.Settings, options ...Option) error { + o, err := setup(c, assignment, transcriptSettings, options...) + if err != nil { + return err + } + + claims := newClaimsManager(c, assignment, o.pool) + + var firstChallenge []small_rational.SmallRational + firstChallenge, err = getChallenges(o.transcript, getFirstChallengeNames(o.nbVars, o.transcriptPrefix)) + if err != nil { + return err + } + + wirePrefix := o.transcriptPrefix + "w" + var baseChallenge [][]byte + for i := len(c) - 1; i >= 0; i-- { + wire := o.sorted[i] + + if wire.IsOutput() { + claims.add(wire, firstChallenge, assignment[wire].Evaluate(firstChallenge, claims.memPool)) + } + + proofW := proof[i] + finalEvalProof := proofW.FinalEvalProof.([]small_rational.SmallRational) + claim := claims.getLazyClaim(wire) + if wire.noProof() { // input wires with one claim only + // make sure the proof is empty + if len(finalEvalProof) != 0 || len(proofW.PartialSumPolys) != 0 { + return fmt.Errorf("no proof allowed for input wire with a single claim") + } + + if wire.NbClaims() == 1 { // input wire + // simply evaluate and see if it matches + evaluation := assignment[wire].Evaluate(claim.evaluationPoints[0], claims.memPool) + if !claim.claimedEvaluations[0].Equal(&evaluation) { + return fmt.Errorf("incorrect input wire claim") + } + } + } else if err = sumcheck.Verify( + claim, proof[i], fiatshamir.WithTranscript(o.transcript, wirePrefix+strconv.Itoa(i)+".", baseChallenge...), + ); err == nil { + baseChallenge = make([][]byte, len(finalEvalProof)) + for j := range finalEvalProof { + bytes := finalEvalProof[j].Bytes() + baseChallenge[j] = bytes[:] + } + } else { + return fmt.Errorf("sumcheck proof rejected: %v", err) //TODO: Any polynomials to dump? + } + claims.deleteClaim(wire) + } + return nil +} + +type IdentityGate struct{} + +func (IdentityGate) Evaluate(input ...small_rational.SmallRational) small_rational.SmallRational { + return input[0] +} + +func (IdentityGate) Degree() int { + return 1 +} + +// outputsList also sets the nbUniqueOutputs fields. It also sets the wire metadata. +func outputsList(c Circuit, indexes map[*Wire]int) [][]int { + res := make([][]int, len(c)) + for i := range c { + res[i] = make([]int, 0) + c[i].nbUniqueOutputs = 0 + if c[i].IsInput() { + c[i].Gate = IdentityGate{} + } + } + ins := make(map[int]struct{}, len(c)) + for i := range c { + for k := range ins { // clear map + delete(ins, k) + } + for _, in := range c[i].Inputs { + inI := indexes[in] + res[inI] = append(res[inI], i) + if _, ok := ins[inI]; !ok { + in.nbUniqueOutputs++ + ins[inI] = struct{}{} + } + } + } + return res +} + +type topSortData struct { + outputs [][]int + status []int // status > 0 indicates number of inputs left to be ready. status = 0 means ready. status = -1 means done + index map[*Wire]int + leastReady int +} + +func (d *topSortData) markDone(i int) { + + d.status[i] = -1 + + for _, outI := range d.outputs[i] { + d.status[outI]-- + if d.status[outI] == 0 && outI < d.leastReady { + d.leastReady = outI + } + } + + for d.leastReady < len(d.status) && d.status[d.leastReady] != 0 { + d.leastReady++ + } +} + +func indexMap(c Circuit) map[*Wire]int { + res := make(map[*Wire]int, len(c)) + for i := range c { + res[&c[i]] = i + } + return res +} + +func statusList(c Circuit) []int { + res := make([]int, len(c)) + for i := range c { + res[i] = len(c[i].Inputs) + } + return res +} + +// TopologicalSort sorts the wires in order of dependence. Such that for any wire, any one it depends on +// occurs before it. It tries to stick to the input order as much as possible. An already sorted list will remain unchanged. +// It also sets the nbOutput flags, and a dummy IdentityGate for input wires. +// Worst-case inefficient O(n^2), but that probably won't matter since the circuits are small. +// Furthermore, it is efficient with already-close-to-sorted lists, which are the expected input +func TopologicalSort(c Circuit) []*Wire { + var data topSortData + data.index = indexMap(c) + data.outputs = outputsList(c, data.index) + data.status = statusList(c) + sorted := make([]*Wire, len(c)) + + for data.leastReady = 0; data.status[data.leastReady] != 0; data.leastReady++ { + } + + for i := range c { + sorted[i] = &c[data.leastReady] + data.markDone(data.leastReady) + } + + return sorted +} + +// Complete the circuit evaluation from input values +func (a WireAssignment) Complete(c Circuit) WireAssignment { + + sortedWires := TopologicalSort(c) + + numEvaluations := 0 + + for _, w := range sortedWires { + if !w.IsInput() { + if numEvaluations == 0 { + numEvaluations = len(a[w.Inputs[0]]) + } + evals := make([]small_rational.SmallRational, numEvaluations) + ins := make([]small_rational.SmallRational, len(w.Inputs)) + for k := 0; k < numEvaluations; k++ { + for inI, in := range w.Inputs { + ins[inI] = a[in][k] + } + evals[k] = w.Gate.Evaluate(ins...) + } + a[w] = evals + } + } + return a +} + +func (a WireAssignment) NumInstances() int { + for _, aW := range a { + return len(aW) + } + panic("empty assignment") +} + +func (a WireAssignment) NumVars() int { + for _, aW := range a { + return aW.NumVars() + } + panic("empty assignment") +} diff --git a/internal/generator/test_vector_utils/small_rational/gkr/gkr_test.go b/internal/generator/test_vector_utils/small_rational/gkr/gkr_test.go new file mode 100644 index 000000000..cc6202aad --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/gkr/gkr_test.go @@ -0,0 +1,48 @@ +package gkr + +import ( + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "testing" +) + +var two = test_vector_utils.ToElement(2) +var three = test_vector_utils.ToElement(3) +var four = test_vector_utils.ToElement(4) + +func TestSingleMulGateTwoInstances(t *testing.T) { + testSingleMulGate(t, []small_rational.SmallRational{*four, *three}, []small_rational.SmallRational{*two, *three}) +} + +func testSingleMulGate(t *testing.T, inputAssignments ...[]small_rational.SmallRational) { + + c := make(Circuit, 3) + c[2] = Wire{ + Gate: mulGate{}, + Inputs: []*Wire{&c[0], &c[1]}, + } + + assignment := WireAssignment{&c[0]: inputAssignments[0], &c[1]: inputAssignments[1]}.Complete(c) + + proof, err := Prove(c, assignment, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err) + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(1, 1))) + assert.NoError(t, err, "proof rejected") + + err = Verify(c, assignment, proof, fiatshamir.WithHash(test_vector_utils.NewMessageCounter(0, 1))) + assert.NotNil(t, err, "bad proof accepted") +} + +type mulGate struct{} + +func (g mulGate) Evaluate(element ...small_rational.SmallRational) (result small_rational.SmallRational) { + result.Mul(&element[0], &element[1]) + return +} + +func (g mulGate) Degree() int { + return 2 +} diff --git a/internal/generator/test_vector_utils/small_rational/polynomial/doc.go b/internal/generator/test_vector_utils/small_rational/polynomial/doc.go new file mode 100644 index 000000000..83479b058 --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/polynomial/doc.go @@ -0,0 +1,18 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +// Package polynomial provides polynomial methods and commitment schemes. +package polynomial diff --git a/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go b/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go new file mode 100644 index 000000000..9535d2783 --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/polynomial/multilin.go @@ -0,0 +1,271 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package polynomial + +import ( + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "math/bits" +) + +// MultiLin tracks the values of a (dense i.e. not sparse) multilinear polynomial +// The variables are X₁ through Xₙ where n = log(len(.)) +// .[∑ᵢ 2ⁱ⁻¹ bₙ₋ᵢ] = the polynomial evaluated at (b₁, b₂, ..., bₙ) +// It is understood that any hypercube evaluation can be extrapolated to a multilinear polynomial +type MultiLin []small_rational.SmallRational + +// Fold is partial evaluation function k[X₁, X₂, ..., Xₙ] → k[X₂, ..., Xₙ] by setting X₁=r +func (m *MultiLin) Fold(r small_rational.SmallRational) { + mid := len(*m) / 2 + + bottom, top := (*m)[:mid], (*m)[mid:] + + // updating bookkeeping table + // knowing that the polynomial f ∈ (k[X₂, ..., Xₙ])[X₁] is linear, we would get f(r) = f(0) + r(f(1) - f(0)) + // the following loop computes the evaluations of f(r) accordingly: + // f(r, b₂, ..., bₙ) = f(0, b₂, ..., bₙ) + r(f(1, b₂, ..., bₙ) - f(0, b₂, ..., bₙ)) + for i := 0; i < mid; i++ { + // table[i] ← table[i] + r (table[i + mid] - table[i]) + top[i].Sub(&top[i], &bottom[i]) + top[i].Mul(&top[i], &r) + bottom[i].Add(&bottom[i], &top[i]) + } + + *m = (*m)[:mid] +} + +func (m MultiLin) Sum() small_rational.SmallRational { + s := m[0] + for i := 1; i < len(m); i++ { + s.Add(&s, &m[i]) + } + return s +} + +func _clone(m MultiLin, p *Pool) MultiLin { + if p == nil { + return m.Clone() + } else { + return p.Clone(m) + } +} + +func _dump(m MultiLin, p *Pool) { + if p != nil { + p.Dump(m) + } +} + +// Evaluate extrapolate the value of the multilinear polynomial corresponding to m +// on the given coordinates +func (m MultiLin) Evaluate(coordinates []small_rational.SmallRational, p *Pool) small_rational.SmallRational { + // Folding is a mutating operation + bkCopy := _clone(m, p) + + // Evaluate step by step through repeated folding (i.e. evaluation at the first remaining variable) + for _, r := range coordinates { + bkCopy.Fold(r) + } + + result := bkCopy[0] + + _dump(bkCopy, p) + return result +} + +// Clone creates a deep copy of a bookkeeping table. +// Both multilinear interpolation and sumcheck require folding an underlying +// array, but folding changes the array. To do both one requires a deep copy +// of the bookkeeping table. +func (m MultiLin) Clone() MultiLin { + res := make(MultiLin, len(m)) + copy(res, m) + return res +} + +// Add two bookKeepingTables +func (m *MultiLin) Add(left, right MultiLin) { + size := len(left) + // Check that left and right have the same size + if len(right) != size || len(*m) != size { + panic("left, right and destination must have the right size") + } + + // Add elementwise + for i := 0; i < size; i++ { + (*m)[i].Add(&left[i], &right[i]) + } +} + +// EvalEq computes Eq(q₁, ... , qₙ, h₁, ... , hₙ) = Π₁ⁿ Eq(qᵢ, hᵢ) +// where Eq(x,y) = xy + (1-x)(1-y) = 1 - x - y + xy + xy interpolates +// +// _________________ +// | | | +// | 0 | 1 | +// |_______|_______| +// y | | | +// | 1 | 0 | +// |_______|_______| +// +// x +// +// In other words the polynomial evaluated here is the multilinear extrapolation of +// one that evaluates to q' == h' for vectors q', h' of binary values +func EvalEq(q, h []small_rational.SmallRational) small_rational.SmallRational { + var res, nxt, one, sum small_rational.SmallRational + one.SetOne() + for i := 0; i < len(q); i++ { + nxt.Mul(&q[i], &h[i]) // nxt <- qᵢ * hᵢ + nxt.Double(&nxt) // nxt <- 2 * qᵢ * hᵢ + nxt.Add(&nxt, &one) // nxt <- 1 + 2 * qᵢ * hᵢ + sum.Add(&q[i], &h[i]) // sum <- qᵢ + hᵢ TODO: Why not subtract one by one from nxt? More parallel? + + if i == 0 { + res.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + } else { + nxt.Sub(&nxt, &sum) // nxt <- 1 + 2 * qᵢ * hᵢ - qᵢ - hᵢ + res.Mul(&res, &nxt) // res <- res * nxt + } + } + return res +} + +// Eq sets m to the representation of the polynomial Eq(q₁, ..., qₙ, *, ..., *) × m[0] +func (m *MultiLin) Eq(q []small_rational.SmallRational) { + n := len(q) + + if len(*m) != 1<= 0; i-- { + res.Mul(&res, v) + res.Add(&res, &(*p)[i]) + } + + return res +} + +// Clone returns a copy of the polynomial +func (p *Polynomial) Clone() Polynomial { + _p := make(Polynomial, len(*p)) + copy(_p, *p) + return _p +} + +// Set to another polynomial +func (p *Polynomial) Set(p1 Polynomial) { + if len(*p) != len(p1) { + *p = p1.Clone() + return + } + + for i := 0; i < len(p1); i++ { + (*p)[i].Set(&p1[i]) + } +} + +// AddConstantInPlace adds a constant to the polynomial, modifying p +func (p *Polynomial) AddConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Add(&(*p)[i], c) + } +} + +// SubConstantInPlace subs a constant to the polynomial, modifying p +func (p *Polynomial) SubConstantInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&(*p)[i], c) + } +} + +// ScaleInPlace multiplies p by v, modifying p +func (p *Polynomial) ScaleInPlace(c *small_rational.SmallRational) { + for i := 0; i < len(*p); i++ { + (*p)[i].Mul(&(*p)[i], c) + } +} + +// Scale multiplies p0 by v, storing the result in p +func (p *Polynomial) Scale(c *small_rational.SmallRational, p0 Polynomial) { + if len(*p) != len(p0) { + *p = make(Polynomial, len(p0)) + } + for i := 0; i < len(p0); i++ { + (*p)[i].Mul(c, &p0[i]) + } +} + +// Add adds p1 to p2 +// This function allocates a new slice unless p == p1 or p == p2 +func (p *Polynomial) Add(p1, p2 Polynomial) *Polynomial { + + bigger := p1 + smaller := p2 + if len(bigger) < len(smaller) { + bigger, smaller = smaller, bigger + } + + if len(*p) == len(bigger) && (&(*p)[0] == &bigger[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &smaller[i]) + } + return p + } + + if len(*p) == len(smaller) && (&(*p)[0] == &smaller[0]) { + for i := 0; i < len(smaller); i++ { + (*p)[i].Add(&(*p)[i], &bigger[i]) + } + *p = append(*p, bigger[len(smaller):]...) + return p + } + + res := make(Polynomial, len(bigger)) + copy(res, bigger) + for i := 0; i < len(smaller); i++ { + res[i].Add(&res[i], &smaller[i]) + } + *p = res + return p +} + +// Sub subtracts p2 from p1 +// TODO make interface more consistent with Add +func (p *Polynomial) Sub(p1, p2 Polynomial) *Polynomial { + if len(p1) != len(p2) || len(p2) != len(*p) { + return nil + } + for i := 0; i < len(*p); i++ { + (*p)[i].Sub(&p1[i], &p2[i]) + } + return p +} + +// Equal checks equality between two polynomials +func (p *Polynomial) Equal(p1 Polynomial) bool { + if (*p == nil) != (p1 == nil) { + return false + } + + if len(*p) != len(p1) { + return false + } + + for i := range p1 { + if !(*p)[i].Equal(&p1[i]) { + return false + } + } + + return true +} + +func (p Polynomial) SetZero() { + for i := 0; i < len(p); i++ { + p[i].SetZero() + } +} + +func (p Polynomial) Text(base int) string { + + var builder strings.Builder + + first := true + for d := len(p) - 1; d >= 0; d-- { + if p[d].IsZero() { + continue + } + + pD := p[d] + pDText := pD.Text(base) + + initialLen := builder.Len() + + if pDText[0] == '-' { + pDText = pDText[1:] + if first { + builder.WriteString("-") + } else { + builder.WriteString(" - ") + } + } else if !first { + builder.WriteString(" + ") + } + + first = false + + if !pD.IsOne() || d == 0 { + builder.WriteString(pDText) + } + + if builder.Len()-initialLen > 10 { + builder.WriteString("×") + } + + if d != 0 { + builder.WriteString("X") + } + if d > 1 { + builder.WriteString( + utils.ToSuperscript(strconv.Itoa(d)), + ) + } + + } + + if first { + return "0" + } + + return builder.String() +} diff --git a/ecc/bls12-381/fr/element_ops_noasm.go b/internal/generator/test_vector_utils/small_rational/polynomial/pool.go similarity index 50% rename from ecc/bls12-381/fr/element_ops_noasm.go rename to internal/generator/test_vector_utils/small_rational/polynomial/pool.go index dcff149fd..737f4f404 100644 --- a/ecc/bls12-381/fr/element_ops_noasm.go +++ b/internal/generator/test_vector_utils/small_rational/polynomial/pool.go @@ -1,6 +1,3 @@ -//go:build !amd64 -// +build !amd64 - // Copyright 2020 ConsenSys Software Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,45 +14,29 @@ // Code generated by consensys/gnark-crypto DO NOT EDIT -package fr +package polynomial -// MulBy3 x *= 3 (mod q) -func MulBy3(x *Element) { - _x := *x - x.Double(x).Add(x, &_x) -} +import ( + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" +) -// MulBy5 x *= 5 (mod q) -func MulBy5(x *Element) { - _x := *x - x.Double(x).Double(x).Add(x, &_x) +// Do as little as possible to instantiate the interface +type Pool struct { } -// MulBy13 x *= 13 (mod q) -func MulBy13(x *Element) { - var y = Element{ - 120259084260, - 15510977298029211676, - 7326335280343703402, - 5909200893219589146, - } - x.Mul(x, &y) +func NewPool(...int) (pool Pool) { + return Pool{} } -// Butterfly sets -// a = a + b (mod q) -// b = a - b (mod q) -func Butterfly(a, b *Element) { - _butterflyGeneric(a, b) -} -func mul(z, x, y *Element) { - _mulGeneric(z, x, y) +func (p *Pool) Make(n int) []small_rational.SmallRational { + return make([]small_rational.SmallRational, n) } -func fromMont(z *Element) { - _fromMontGeneric(z) +func (p *Pool) Dump(...[]small_rational.SmallRational) { } -func reduce(z *Element) { - _reduceGeneric(z) +func (p *Pool) Clone(slice []small_rational.SmallRational) []small_rational.SmallRational { + res := p.Make(len(slice)) + copy(res, slice) + return res } diff --git a/internal/generator/test_vector_utils/small_rational/small-rational.go b/internal/generator/test_vector_utils/small_rational/small-rational.go new file mode 100644 index 000000000..352cc0f9d --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/small-rational.go @@ -0,0 +1,390 @@ +package small_rational + +import ( + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" +) + +const Bytes = 64 + +type SmallRational struct { + text string //For debugging purposes + numerator big.Int + denominator big.Int // By convention, denominator == 0 also indicates zero +} + +var smallPrimes = []*big.Int{ + big.NewInt(2), big.NewInt(3), big.NewInt(5), + big.NewInt(7), big.NewInt(11), big.NewInt(13), +} + +func bigDivides(p, a *big.Int) bool { + var remainder big.Int + remainder.Mod(a, p) + return remainder.BitLen() == 0 +} + +func (z *SmallRational) UpdateText() { + z.text = z.Text(10) +} + +func (z *SmallRational) simplify() { + + if z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 { + return + } + + var num, den big.Int + + num.Set(&z.numerator) + den.Set(&z.denominator) + + for _, p := range smallPrimes { + for bigDivides(p, &num) && bigDivides(p, &den) { + num.Div(&num, p) + den.Div(&den, p) + } + } + + z.numerator = num + z.denominator = den + +} +func (z *SmallRational) Square(x *SmallRational) *SmallRational { + var num, den big.Int + num.Mul(&x.numerator, &x.numerator) + den.Mul(&x.denominator, &x.denominator) + + z.numerator = num + z.denominator = den + + z.UpdateText() + + return z +} + +func (z *SmallRational) String() string { + z.text = z.Text(10) + return z.text +} + +func (z *SmallRational) Add(x, y *SmallRational) *SmallRational { + if x.denominator.BitLen() == 0 { + *z = *y + } else if y.denominator.BitLen() == 0 { + *z = *x + } else { + //TODO: Exploit cases where one denom divides the other + var numDen, denNum big.Int + numDen.Mul(&x.numerator, &y.denominator) + denNum.Mul(&x.denominator, &y.numerator) + + numDen.Add(&denNum, &numDen) + z.numerator = numDen //to avoid shallow copy problems + + denNum.Mul(&x.denominator, &y.denominator) + z.denominator = denNum + z.simplify() + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) IsZero() bool { + return z.numerator.BitLen() == 0 || z.denominator.BitLen() == 0 +} + +func (z *SmallRational) Inverse(x *SmallRational) *SmallRational { + if x.IsZero() { + *z = *x + } else { + *z = SmallRational{numerator: x.denominator, denominator: x.numerator} + z.UpdateText() + } + + return z +} + +func (z *SmallRational) Neg(x *SmallRational) *SmallRational { + z.numerator.Neg(&x.numerator) + z.denominator = x.denominator + + if x.text == "" { + x.UpdateText() + } + + if x.text[0] == '-' { + z.text = x.text[1:] + } else { + z.text = "-" + x.text + } + + return z +} + +func (z *SmallRational) Double(x *SmallRational) *SmallRational { + + var y big.Int + + if x.denominator.Bit(0) == 0 { + z.numerator = x.numerator + y.Rsh(&x.denominator, 1) + z.denominator = y + } else { + y.Lsh(&x.numerator, 1) + z.numerator = y + z.denominator = x.denominator + } + + z.UpdateText() + + return z +} + +func (z *SmallRational) Sign() int { + return z.numerator.Sign() * z.denominator.Sign() +} + +func (z *SmallRational) MarshalJSON() ([]byte, error) { + return []byte(z.String()), nil +} + +func (z *SmallRational) UnmarshalJson(data []byte) error { + _, err := z.SetInterface(string(data)) + return err +} + +func (z *SmallRational) Equal(x *SmallRational) bool { + return z.Cmp(x) == 0 +} + +func (z *SmallRational) Sub(x, y *SmallRational) *SmallRational { + var yNeg SmallRational + yNeg.Neg(y) + z.Add(x, &yNeg) + + z.UpdateText() + return z +} + +func (z *SmallRational) Cmp(x *SmallRational) int { + zSign, xSign := z.Sign(), x.Sign() + + if zSign > xSign { + return 1 + } + if zSign < xSign { + return -1 + } + + var Z, X big.Int + Z.Mul(&z.numerator, &x.denominator) + X.Mul(&x.numerator, &z.denominator) + + Z.Abs(&Z) + X.Abs(&X) + + return Z.Cmp(&X) * zSign + +} + +func BatchInvert(a []SmallRational) []SmallRational { + res := make([]SmallRational, len(a)) + for i := range a { + res[i].Inverse(&a[i]) + } + return res +} + +func (z *SmallRational) Mul(x, y *SmallRational) *SmallRational { + var num, den big.Int + + num.Mul(&x.numerator, &y.numerator) + den.Mul(&x.denominator, &y.denominator) + + z.numerator = num + z.denominator = den + + z.simplify() + z.UpdateText() + return z +} + +func (z *SmallRational) SetOne() *SmallRational { + return z.SetInt64(1) +} + +func (z *SmallRational) SetZero() *SmallRational { + return z.SetInt64(0) +} + +func (z *SmallRational) SetInt64(i int64) *SmallRational { + z.numerator = *big.NewInt(i) + z.denominator = *big.NewInt(1) + z.text = strconv.FormatInt(i, 10) + return z +} + +func (z *SmallRational) SetRandom() (*SmallRational, error) { + + bytes := make([]byte, 1) + n, err := rand.Read(bytes) + if err != nil { + return nil, err + } + if n != len(bytes) { + return nil, fmt.Errorf("%d bytes read instead of %d", n, len(bytes)) + } + + z.numerator = *big.NewInt(int64(bytes[0]%16) - 8) + z.denominator = *big.NewInt(int64((bytes[0]) / 16)) + + z.simplify() + z.UpdateText() + + return z, nil +} + +func (z *SmallRational) SetUint64(i uint64) { + var num big.Int + num.SetUint64(i) + z.numerator = num + z.denominator = *big.NewInt(1) + z.text = strconv.FormatUint(i, 10) +} + +func (z *SmallRational) IsOne() bool { + return z.numerator.Cmp(&z.denominator) == 0 && z.denominator.BitLen() != 0 +} + +func (z *SmallRational) Text(base int) string { + + if z.denominator.BitLen() == 0 { + return "0" + } + + if z.denominator.Sign() < 0 { + var num, den big.Int + num.Neg(&z.numerator) + den.Neg(&z.denominator) + z.numerator = num + z.denominator = den + } + + if bigDivides(&z.denominator, &z.numerator) { + var num big.Int + num.Div(&z.numerator, &z.denominator) + z.numerator = num + z.denominator = *big.NewInt(1) + } + + numerator := z.numerator.Text(base) + + if z.denominator.IsInt64() && z.denominator.Int64() == 1 { + return numerator + } + + return numerator + "/" + z.denominator.Text(base) +} + +func (z *SmallRational) Set(x *SmallRational) *SmallRational { + *z = *x // shallow copy is safe because ops are never in place + return z +} + +func (z *SmallRational) SetInterface(x interface{}) (*SmallRational, error) { + + switch v := x.(type) { + case *SmallRational: + *z = *v + case SmallRational: + *z = v + case int64: + z.SetInt64(v) + case int: + z.SetInt64(int64(v)) + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + case string: + z.text = v + sep := strings.Split(v, "/") + switch len(sep) { + case 1: + if asInt, err := strconv.Atoi(sep[0]); err == nil { + z.SetInt64(int64(asInt)) + } else { + return nil, err + } + case 2: + var err error + var num, denom int + num, err = strconv.Atoi(sep[0]) + if err != nil { + return nil, err + } + denom, err = strconv.Atoi(sep[1]) + if err != nil { + return nil, err + } + z.numerator = *big.NewInt(int64(num)) + z.denominator = *big.NewInt(int64(denom)) + default: + return nil, fmt.Errorf("cannot parse \"%s\"", v) + } + default: + return nil, fmt.Errorf("cannot parse %T", x) + } + + return z, nil +} + +func bigIntToBytesSigned(dst []byte, src big.Int) { + src.FillBytes(dst[1:]) + dst[0] = 0 + if src.Sign() < 0 { + dst[0] = 255 + } +} + +func (z *SmallRational) Bytes() [Bytes]byte { + var res [Bytes]byte + bigIntToBytesSigned(res[:Bytes/2], z.numerator) + bigIntToBytesSigned(res[Bytes/2:], z.denominator) + return res +} + +func bytesToBigIntSigned(src []byte) big.Int { + var res big.Int + res.SetBytes(src[1:]) + if src[0] != 0 { + res.Neg(&res) + } + return res +} + +func (z *SmallRational) SetBytes(b []byte) { + if len(b) > Bytes/2 { + z.numerator = bytesToBigIntSigned(b[:Bytes/2]) + z.denominator = bytesToBigIntSigned(b[Bytes/2:]) + } else { + z.numerator.SetBytes(b) + z.denominator.SetInt64(1) + } + z.simplify() + z.UpdateText() +} + +func Modulus() *big.Int { + res := big.NewInt(1) + res.Lsh(res, 64) + return res +} diff --git a/internal/generator/test_vector_utils/small_rational/small_rational_test.go b/internal/generator/test_vector_utils/small_rational/small_rational_test.go new file mode 100644 index 000000000..4a91687fa --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/small_rational_test.go @@ -0,0 +1,109 @@ +package small_rational + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestCmp(t *testing.T) { + + cases := make([]SmallRational, 36) + + for i := int64(0); i < 9; i++ { + if i%2 == 0 { + cases[4*i].numerator.SetInt64((i - 4) / 2) + cases[4*i].denominator.SetInt64(1) + } else { + cases[4*i].numerator.SetInt64(i - 4) + cases[4*i].denominator.SetInt64(2) + } + + cases[4*i+1].numerator.Neg(&cases[4*i].numerator) + cases[4*i+1].denominator.Neg(&cases[4*i].denominator) + + cases[4*i+2].numerator.Lsh(&cases[4*i].numerator, 1) + cases[4*i+2].denominator.Lsh(&cases[4*i].denominator, 1) + + cases[4*i+3].numerator.Neg(&cases[4*i+2].numerator) + cases[4*i+3].denominator.Neg(&cases[4*i+2].denominator) + } + + for i := range cases { + for j := range cases { + I, J := i/4, j/4 + var expectedCmp int + cmp := cases[i].Cmp(&cases[j]) + if I < J { + expectedCmp = -1 + } else if I == J { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + assert.Equal(t, expectedCmp, cmp, "comparing index %d, index %d", i, j) + } + } + + zeroIndex := len(cases) / 8 + var weirdZero SmallRational + for i := range cases { + I := i / 4 + var expectedCmp int + cmp := cases[i].Cmp(&weirdZero) + cmpNeg := weirdZero.Cmp(&cases[i]) + if I < zeroIndex { + expectedCmp = -1 + } else if I == zeroIndex { + expectedCmp = 0 + } else { + expectedCmp = 1 + } + + assert.Equal(t, expectedCmp, cmp, "comparing index %d, 0/0", i) + assert.Equal(t, -expectedCmp, cmpNeg, "comparing 0/0, index %d", i) + } +} + +func TestDouble(t *testing.T) { + values := []interface{}{1, 2, 3, 4, 5, "2/3", "3/2", "-3/-2"} + valsDoubled := []interface{}{2, 4, 6, 8, 10, "-4/-3", 3, 3} + + for i := range values { + var v, vDoubled, vDoubledExpected SmallRational + _, err := v.SetInterface(values[i]) + assert.NoError(t, err) + _, err = vDoubledExpected.SetInterface(valsDoubled[i]) + assert.NoError(t, err) + vDoubled.Double(&v) + assert.True(t, vDoubled.Equal(&vDoubledExpected), + "mismatch at %d: expected 2×%s = %s, saw %s", i, v.text, vDoubledExpected.text, vDoubled.text) + + } +} + +func TestOperandConstancy(t *testing.T) { + var p0, p, pPure SmallRational + p0.SetInt64(1) + p.SetInt64(-3) + pPure.SetInt64(-3) + + res := p + res.Add(&res, &p0) + assert.True(t, p.Equal(&pPure)) +} + +func TestSquare(t *testing.T) { + var two, four, x SmallRational + two.SetInt64(2) + four.SetInt64(4) + + x.Square(&two) + + assert.True(t, x.Equal(&four), "expected 4, saw %s", x.Text(10)) +} + +func TestSetBytes(t *testing.T) { + var c SmallRational + c.SetBytes([]byte("firstChallenge.0")) + +} diff --git a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go new file mode 100644 index 000000000..940a8a27c --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck.go @@ -0,0 +1,181 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "strconv" +) + +// This does not make use of parallelism and represents polynomials as lists of coefficients +// It is currently geared towards arithmetic hashes. Once we have a more unified hash function interface, this can be generified. + +// Claims to a multi-sumcheck statement. i.e. one of the form ∑_{0≤i<2ⁿ} fⱼ(i) = cⱼ for 1 ≤ j ≤ m. +// Later evolving into a claim of the form gⱼ = ∑_{0≤i<2ⁿ⁻ʲ} g(r₁, r₂, ..., rⱼ₋₁, Xⱼ, i...) +type Claims interface { + Combine(a small_rational.SmallRational) polynomial.Polynomial // Combine into the 0ᵗʰ sumcheck subclaim. Create g := ∑_{1≤j≤m} aʲ⁻¹fⱼ for which now we seek to prove ∑_{0≤i<2ⁿ} g(i) = c := ∑_{1≤j≤m} aʲ⁻¹cⱼ. Return g₁. + Next(small_rational.SmallRational) polynomial.Polynomial // Return the evaluations gⱼ(k) for 1 ≤ k < degⱼ(g). Update the claim to gⱼ₊₁ for the input value as rⱼ + VarsNum() int //number of variables + ClaimsNum() int //number of claims + ProveFinalEval(r []small_rational.SmallRational) interface{} //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +// LazyClaims is the Claims data structure on the verifier side. It is "lazy" in that it has to compute fewer things. +type LazyClaims interface { + ClaimsNum() int // ClaimsNum = m + VarsNum() int // VarsNum = n + CombinedSum(a small_rational.SmallRational) small_rational.SmallRational // CombinedSum returns c = ∑_{1≤j≤m} aʲ⁻¹cⱼ + Degree(i int) int //Degree of the total claim in the i'th variable + VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error +} + +// Proof of a multi-sumcheck statement. +type Proof struct { + PartialSumPolys []polynomial.Polynomial `json:"partialSumPolys"` + FinalEvalProof interface{} `json:"finalEvalProof"` //in case it is difficult for the verifier to compute g(r₁, ..., rₙ) on its own, the prover can provide the value and a proof +} + +func setupTranscript(claimsNum int, varsNum int, settings *fiatshamir.Settings) (challengeNames []string, err error) { + numChallenges := varsNum + if claimsNum >= 2 { + numChallenges++ + } + challengeNames = make([]string, numChallenges) + if claimsNum >= 2 { + challengeNames[0] = settings.Prefix + "comb" + } + prefix := settings.Prefix + "pSP." + for i := 0; i < varsNum; i++ { + challengeNames[i+numChallenges-varsNum] = prefix + strconv.Itoa(i) + } + if settings.Transcript == nil { + transcript := fiatshamir.NewTranscript(settings.Hash, challengeNames...) + settings.Transcript = &transcript + } + + for i := range settings.BaseChallenges { + if err = settings.Transcript.Bind(challengeNames[0], settings.BaseChallenges[i]); err != nil { + return + } + } + return +} + +func next(transcript *fiatshamir.Transcript, bindings []small_rational.SmallRational, remainingChallengeNames *[]string) (small_rational.SmallRational, error) { + challengeName := (*remainingChallengeNames)[0] + for i := range bindings { + bytes := bindings[i].Bytes() + if err := transcript.Bind(challengeName, bytes[:]); err != nil { + return small_rational.SmallRational{}, err + } + } + var res small_rational.SmallRational + bytes, err := transcript.ComputeChallenge(challengeName) + res.SetBytes(bytes) + + *remainingChallengeNames = (*remainingChallengeNames)[1:] + + return res, err +} + +// Prove create a non-interactive sumcheck proof +func Prove(claims Claims, transcriptSettings fiatshamir.Settings) (Proof, error) { + + var proof Proof + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return proof, err + } + + var combinationCoeff small_rational.SmallRational + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return proof, err + } + } + + varsNum := claims.VarsNum() + proof.PartialSumPolys = make([]polynomial.Polynomial, varsNum) + proof.PartialSumPolys[0] = claims.Combine(combinationCoeff) + challenges := make([]small_rational.SmallRational, varsNum) + + for j := 0; j+1 < varsNum; j++ { + if challenges[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return proof, err + } + proof.PartialSumPolys[j+1] = claims.Next(challenges[j]) + } + + if challenges[varsNum-1], err = next(transcript, proof.PartialSumPolys[varsNum-1], &remainingChallengeNames); err != nil { + return proof, err + } + + proof.FinalEvalProof = claims.ProveFinalEval(challenges) + + return proof, nil +} + +func Verify(claims LazyClaims, proof Proof, transcriptSettings fiatshamir.Settings) error { + remainingChallengeNames, err := setupTranscript(claims.ClaimsNum(), claims.VarsNum(), &transcriptSettings) + transcript := transcriptSettings.Transcript + if err != nil { + return err + } + + var combinationCoeff small_rational.SmallRational + + if claims.ClaimsNum() >= 2 { + if combinationCoeff, err = next(transcript, []small_rational.SmallRational{}, &remainingChallengeNames); err != nil { + return err + } + } + + r := make([]small_rational.SmallRational, claims.VarsNum()) + + // Just so that there is enough room for gJ to be reused + maxDegree := claims.Degree(0) + for j := 1; j < claims.VarsNum(); j++ { + if d := claims.Degree(j); d > maxDegree { + maxDegree = d + } + } + gJ := make(polynomial.Polynomial, maxDegree+1) //At the end of iteration j, gJ = ∑_{i < 2ⁿ⁻ʲ⁻¹} g(X₁, ..., Xⱼ₊₁, i...) NOTE: n is shorthand for claims.VarsNum() + gJR := claims.CombinedSum(combinationCoeff) // At the beginning of iteration j, gJR = ∑_{i < 2ⁿ⁻ʲ} g(r₁, ..., rⱼ, i...) + + for j := 0; j < claims.VarsNum(); j++ { + if len(proof.PartialSumPolys[j]) != claims.Degree(j) { + return fmt.Errorf("malformed proof") + } + copy(gJ[1:], proof.PartialSumPolys[j]) + gJ[0].Sub(&gJR, &proof.PartialSumPolys[j][0]) // Requirement that gⱼ(0) + gⱼ(1) = gⱼ₋₁(r) + // gJ is ready + + //Prepare for the next iteration + if r[j], err = next(transcript, proof.PartialSumPolys[j], &remainingChallengeNames); err != nil { + return err + } + // This is an extremely inefficient way of interpolating. TODO: Interpolate without symbolically computing a polynomial + gJCoeffs := polynomial.InterpolateOnRange(gJ[:(claims.Degree(j) + 1)]) + gJR = gJCoeffs.Eval(&r[j]) + } + + return claims.VerifyFinalEval(r, combinationCoeff, gJR, proof.FinalEvalProof) +} diff --git a/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go new file mode 100644 index 000000000..e30e37b4d --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/sumcheck/sumcheck_test.go @@ -0,0 +1,161 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package sumcheck + +import ( + "fmt" + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/test_vector_utils" + "github.com/stretchr/testify/assert" + "hash" + "math/bits" + "strings" + "testing" +) + +type singleMultilinClaim struct { + g polynomial.MultiLin +} + +func (c singleMultilinClaim) ProveFinalEval(r []small_rational.SmallRational) interface{} { + return nil // verifier can compute the final eval itself +} + +func (c singleMultilinClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func (c singleMultilinClaim) ClaimsNum() int { + return 1 +} + +func sumForX1One(g polynomial.MultiLin) polynomial.Polynomial { + sum := g[len(g)/2] + for i := len(g)/2 + 1; i < len(g); i++ { + sum.Add(&sum, &g[i]) + } + return []small_rational.SmallRational{sum} +} + +func (c singleMultilinClaim) Combine(small_rational.SmallRational) polynomial.Polynomial { + return sumForX1One(c.g) +} + +func (c *singleMultilinClaim) Next(r small_rational.SmallRational) polynomial.Polynomial { + c.g.Fold(r) + return sumForX1One(c.g) +} + +type singleMultilinLazyClaim struct { + g polynomial.MultiLin + claimedSum small_rational.SmallRational +} + +func (c singleMultilinLazyClaim) VerifyFinalEval(r []small_rational.SmallRational, combinationCoeff small_rational.SmallRational, purportedValue small_rational.SmallRational, proof interface{}) error { + val := c.g.Evaluate(r, nil) + if val.Equal(&purportedValue) { + return nil + } + return fmt.Errorf("mismatch") +} + +func (c singleMultilinLazyClaim) CombinedSum(combinationCoeffs small_rational.SmallRational) small_rational.SmallRational { + return c.claimedSum +} + +func (c singleMultilinLazyClaim) Degree(i int) int { + return 1 +} + +func (c singleMultilinLazyClaim) ClaimsNum() int { + return 1 +} + +func (c singleMultilinLazyClaim) VarsNum() int { + return bits.TrailingZeros(uint(len(c.g))) +} + +func testSumcheckSingleClaimMultilin(polyInt []uint64, hashGenerator func() hash.Hash) error { + poly := make(polynomial.MultiLin, len(polyInt)) + for i, n := range polyInt { + poly[i].SetUint64(n) + } + + claim := singleMultilinClaim{g: poly.Clone()} + + proof, err := Prove(&claim, fiatshamir.WithHash(hashGenerator())) + if err != nil { + return err + } + + var sb strings.Builder + for _, p := range proof.PartialSumPolys { + + sb.WriteString("\t{") + for i := 0; i < len(p); i++ { + sb.WriteString(p[i].String()) + if i+1 < len(p) { + sb.WriteString(", ") + } + } + sb.WriteString("}\n") + } + + lazyClaim := singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if err = Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())); err != nil { + return err + } + + proof.PartialSumPolys[0][0].Add(&proof.PartialSumPolys[0][0], test_vector_utils.ToElement(1)) + lazyClaim = singleMultilinLazyClaim{g: poly, claimedSum: poly.Sum()} + if Verify(lazyClaim, proof, fiatshamir.WithHash(hashGenerator())) == nil { + return fmt.Errorf("bad proof accepted") + } + return nil +} + +func TestSumcheckDeterministicHashSingleClaimMultilin(t *testing.T) { + //printMsws(36) + + polys := [][]uint64{ + {1, 2, 3, 4}, // 1 + 2X₁ + X₂ + {1, 2, 3, 4, 5, 6, 7, 8}, // 1 + 4X₁ + 2X₂ + X₃ + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 1 + 8X₁ + 4X₂ + 2X₃ + X₄ + } + + const MaxStep = 4 + const MaxStart = 4 + hashGens := make([]func() hash.Hash, 0, MaxStart*MaxStep) + + for step := 0; step < MaxStep; step++ { + for startState := 0; startState < MaxStart; startState++ { + if step == 0 && startState == 1 { // unlucky case where a bad proof would be accepted + continue + } + hashGens = append(hashGens, test_vector_utils.NewMessageCounterGenerator(startState, step)) + } + } + + for _, poly := range polys { + for _, hashGen := range hashGens { + assert.NoError(t, testSumcheckSingleClaimMultilin(poly, hashGen), + "failed with poly %v and hashGen %v", poly, hashGen()) + } + } +} diff --git a/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go b/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go new file mode 100644 index 000000000..00eb39b48 --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils.go @@ -0,0 +1,444 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by consensys/gnark-crypto DO NOT EDIT + +package test_vector_utils + +import ( + "encoding/json" + "fmt" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational/polynomial" + "hash" + "math/rand" + "os" + "path/filepath" + "reflect" + "sort" + "strconv" + "strings" +) + +type ElementTriplet struct { + key1 small_rational.SmallRational + key2 small_rational.SmallRational + key2Present bool + value small_rational.SmallRational + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := entry.value.SetInterface(v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := entry.key2.SetInterface(key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := entry.key1.SetInterface(key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state small_rational.SmallRational + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x small_rational.SmallRational + for i := 0; i < len(p); i += small_rational.Bytes { + x.SetBytes(p[i:min(len(p), i+small_rational.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return small_rational.Bytes +} + +func (m *MapHash) BlockSize() int { + return small_rational.Bytes +} + +func (m *MapHash) write(x small_rational.SmallRational) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} + +func SaveUsedHashEntries() error { + for path, hash := range MapCache { + if err := hash.SaveUsedEntries(path); err != nil { + return err + } + } + return nil +} + +func (t *ElementTriplet) writeKeyValue(sb *strings.Builder) error { + t.writeKey(sb) + sb.WriteRune(':') + + if valueBytes, err := json.Marshal(ElementToInterface(&t.value)); err == nil { + sb.WriteString(string(valueBytes)) + return nil + } else { + return err + } +} + +func (m *ElementMap) serializedUsedEntries() (string, error) { + var sb strings.Builder + sb.WriteRune('{') + + first := true + + for _, element := range *m { + if !element.used { + continue + } + if !first { + sb.WriteRune(',') + } + first = false + sb.WriteString("\n\t") + if err := element.writeKeyValue(&sb); err != nil { + return "", err + } + } + + sb.WriteString("\n}") + + return sb.String(), nil +} + +func (m *ElementMap) SaveUsedEntries(path string) error { + + if s, err := m.serializedUsedEntries(); err != nil { + return err + } else { + return os.WriteFile(path, []byte(s), 0) + } +} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) (small_rational.SmallRational, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + // if not found, add it: + if _, err := toFind.value.SetInterface(rand.Int63n(11) - 5); err != nil { + panic(err.Error()) + } + toFind.used = true + *m = append(*m, toFind) + m.sort() //Inefficient, but it's okay. This is only run when a new test case is introduced + + return toFind.value, nil +} + +func (m *ElementMap) FindPair(x *small_rational.SmallRational, y *small_rational.SmallRational) (small_rational.SmallRational, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *small_rational.SmallRational { + var res small_rational.SmallRational + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/small_rational.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/small_rational.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res small_rational.SmallRational + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return small_rational.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return small_rational.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []small_rational.SmallRational + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return small_rational.Bytes +} + +func (h *ListHash) BlockSize() int { + return small_rational.Bytes +} + +func SliceToElementSlice[T any](slice []T) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]small_rational.SmallRational, b [][]small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i], b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *small_rational.SmallRational) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().(small_rational.SmallRational) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils_test.go b/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils_test.go new file mode 100644 index 000000000..9fb77e99a --- /dev/null +++ b/internal/generator/test_vector_utils/small_rational/test_vector_utils/test_vector_utils_test.go @@ -0,0 +1,76 @@ +package test_vector_utils + +import ( + fiatshamir "github.com/consensys/gnark-crypto/fiat-shamir" + "github.com/consensys/gnark-crypto/internal/generator/test_vector_utils/small_rational" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestHashNewElementSaved(t *testing.T) { + var hash ElementMap + + var fortyFour small_rational.SmallRational + fortyFour.SetInt64(44) + + expected, err := hash.FindPair(&fortyFour, nil) + assert.NoError(t, err) + for i := 0; i < 10; i++ { + seen, err := hash.FindPair(&fortyFour, nil) + assert.NoError(t, err) + if !expected.Equal(&seen) { + t.Errorf("expected %s saw %s", expected.String(), seen.String()) + } + } +} + +func TestHashConsistency(t *testing.T) { + var one small_rational.SmallRational + var mp ElementMap + one.SetOne() + bytes := one.Bytes() + + t1 := fiatshamir.NewTranscript(&MapHash{Map: &mp}, "0") + assert.NoError(t, t1.Bind("0", bytes[:])) + c1, err := t1.ComputeChallenge("0") + assert.NoError(t, err) + + t2 := fiatshamir.NewTranscript(&MapHash{Map: &mp}, "0") + assert.NoError(t, t2.Bind("0", bytes[:])) + c2, err := t2.ComputeChallenge("0") + assert.NoError(t, err) + + assert.Equal(t, c1, c2) +} + +func TestSaveHash(t *testing.T) { + + var one, two, three small_rational.SmallRational + one.SetInt64(1) + two.SetInt64(2) + three.SetInt64(3) + + hash := ElementMap{{ + key1: one, + key2: small_rational.SmallRational{}, + key2Present: false, + value: two, + used: true, + }, { + key1: one, + key2: one, + key2Present: true, + value: three, + used: true, + }, { + key1: two, + key2: one, + key2Present: true, + value: two, + used: false, + }} + + serialized, err := hash.serializedUsedEntries() + assert.NoError(t, err) + assert.Equal(t, "{\n\t\"1\":2,\n\t\"1,1\":3\n}", serialized) +} diff --git a/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl b/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl new file mode 100644 index 000000000..12b63bfb9 --- /dev/null +++ b/internal/generator/test_vector_utils/template/test_vector_utils.go.tmpl @@ -0,0 +1,494 @@ +import ( + "encoding/json" + "fmt" + "{{.FieldPackagePath}}" + "{{.FieldPackagePath}}/polynomial" + "hash" + {{if .RandomizeMissingHashEntries}}"math/rand"{{end}} + "os" + "path/filepath" + "sort" + "strings" + "strconv" + "reflect" +) + +type ElementTriplet struct { + key1 {{.ElementType}} + key2 {{.ElementType}} + key2Present bool + value {{.ElementType}} + used bool +} + +func (t *ElementTriplet) CmpKey(o *ElementTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +var MapCache = make(map[string]*ElementMap) + +func ElementMapFromFile(path string) (*ElementMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := MapCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + var h ElementMap + if h, err = CreateElementMap(asMap); err == nil { + MapCache[path] = &h + } + + return &h, err + + } else { + return nil, err + } +} + +func CreateElementMap(rawMap map[string]interface{}) (ElementMap, error) { + res := make(ElementMap, 0, len(rawMap)) + + for k, v := range rawMap { + var entry ElementTriplet + if _, err := {{setElement "entry.value" "v" .ElementType }}; err != nil { + return nil, err + } + + key := strings.Split(k, ",") + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err := {{setElement "entry.key2" "key[1]" .ElementType}}; err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err := {{setElement "entry.key1" "key[0]" .ElementType }}; err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + return res, nil +} + +type ElementMap []*ElementTriplet + +type MapHash struct { + Map *ElementMap + state {{.ElementType}} + stateValid bool +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func (m *MapHash) Write(p []byte) (n int, err error) { + var x {{.ElementType}} + for i := 0; i < len(p); i += {{.FieldPackageName}}.Bytes { + x.SetBytes(p[i:min(len(p), i+{{.FieldPackageName}}.Bytes)]) + if err = m.write(x); err != nil { + return + } + } + n = len(p) + return +} + +func (m *MapHash) Sum(b []byte) []byte { + mP := *m + if _, err := mP.Write(b); err != nil { + panic(err) + } + bytes := mP.state.Bytes() + return bytes[:] +} + +func (m *MapHash) Reset() { + m.stateValid = false +} + +func (m *MapHash) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (m *MapHash) BlockSize() int { + return {{.FieldPackageName}}.Bytes +} + +func (m *MapHash) write(x {{.ElementType}}) error { + X := &x + Y := &m.state + if !m.stateValid { + Y = nil + } + var err error + if m.state, err = m.Map.FindPair(X, Y); err == nil { + m.stateValid = true + } + return err +} + +func (t *ElementTriplet) writeKey(sb *strings.Builder) { + sb.WriteRune('"') + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteRune('"') +} + +{{- if .RandomizeMissingHashEntries}} + +func SaveUsedHashEntries() error { + for path, hash := range MapCache { + if err := hash.SaveUsedEntries(path); err != nil { + return err + } + } + return nil +} + +func (t *ElementTriplet) writeKeyValue(sb *strings.Builder) error { + t.writeKey(sb) + sb.WriteRune(':') + + if valueBytes, err := json.Marshal(ElementToInterface(&t.value)); err == nil { + sb.WriteString(string(valueBytes)) + return nil + } else { + return err + } +} + +func (m *ElementMap) serializedUsedEntries() (string, error) { + var sb strings.Builder + sb.WriteRune('{') + + first := true + + for _, element := range *m { + if !element.used { + continue + } + if !first { + sb.WriteRune(',') + } + first = false + sb.WriteString("\n\t") + if err := element.writeKeyValue(&sb); err != nil { + return "", err + } + } + + sb.WriteString("\n}") + + return sb.String(), nil +} + +func (m *ElementMap) SaveUsedEntries(path string) error { + + if s, err := m.serializedUsedEntries(); err != nil { + return err + } else { + return os.WriteFile(path, []byte(s), 0) + } +} +{{- else}} +func (m *ElementMap) UnusedEntries() []interface{} { + unused := make([]interface{}, 0) + for _, v := range *m { + if !v.used { + var vInterface interface{} + if v.key2Present { + vInterface = []interface{}{ElementToInterface(&v.key1), ElementToInterface(&v.key2)} + } else { + vInterface = ElementToInterface(&v.key1) + } + unused = append(unused, vInterface) + } + } + return unused +} +{{- end}} + +func (m *ElementMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *ElementMap) find(toFind *ElementTriplet) ({{.ElementType}}, error) { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value, nil + } + + {{- if .RandomizeMissingHashEntries}} + // if not found, add it: + if _, err := toFind.value.SetInterface(rand.Int63n(11) - 5); err != nil { + panic(err.Error()) + } + toFind.used = true + *m = append(*m, toFind) + m.sort() //Inefficient, but it's okay. This is only run when a new test case is introduced + + return toFind.value, nil + {{- else}} + var sb strings.Builder + sb.WriteString("no value available for input ") + toFind.writeKey(&sb) + return {{.ElementType}}{}, fmt.Errorf(sb.String()) + {{- end}} +} + +func (m *ElementMap) FindPair(x *{{.ElementType}}, y *{{.ElementType}}) ({{.ElementType}}, error) { + + toFind := ElementTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func ToElement(i int64) *{{.ElementType}} { + var res {{.ElementType}} + res.SetInt64(i) + return &res +} + +type MessageCounter struct { + startState uint64 + state uint64 + step uint64 +} + +func (m *MessageCounter) Write(p []byte) (n int, err error) { + inputBlockSize := (len(p)-1)/{{.FieldPackageName}}.Bytes + 1 + m.state += uint64(inputBlockSize) * m.step + return len(p), nil +} + +func (m *MessageCounter) Sum(b []byte) []byte { + inputBlockSize := (len(b)-1)/{{.FieldPackageName}}.Bytes + 1 + resI := m.state + uint64(inputBlockSize)*m.step + var res {{.ElementType}} + res.SetInt64(int64(resI)) + resBytes := res.Bytes() + return resBytes[:] +} + +func (m *MessageCounter) Reset() { + m.state = m.startState +} + +func (m *MessageCounter) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (m *MessageCounter) BlockSize() int { + return {{.FieldPackageName}}.Bytes +} + +func NewMessageCounter(startState, step int) hash.Hash { + transcript := &MessageCounter{startState: uint64(startState), state: uint64(startState), step: uint64(step)} + return transcript +} + +func NewMessageCounterGenerator(startState, step int) func() hash.Hash { + return func() hash.Hash { + return NewMessageCounter(startState, step) + } +} + +type ListHash []{{.ElementType}} + +func (h *ListHash) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (h *ListHash) Sum(b []byte) []byte { + res := (*h)[0].Bytes() + *h = (*h)[1:] + return res[:] +} + +func (h *ListHash) Reset() { +} + +func (h *ListHash) Size() int { + return {{.FieldPackageName}}.Bytes +} + +func (h *ListHash) BlockSize() int { +return {{.FieldPackageName}}.Bytes +} + +{{- if eq .ElementType "fr.Element"}} +func SetElement(z *fr.Element, value interface{}) (*fr.Element, error) { + + // TODO: Put this in element.SetString? + switch v := value.(type) { + case string: + + if sep := strings.Split(v, "/"); len(sep) == 2 { + var denom fr.Element + if _, err := z.SetString(sep[0]); err != nil { + return nil, err + } + if _, err := denom.SetString(sep[1]); err != nil { + return nil, err + } + denom.Inverse(&denom) + z.Mul(z, &denom) + return z, nil + } + + case float64: + asInt := int64(v) + if float64(asInt) != v { + return nil, fmt.Errorf("cannot currently parse float") + } + z.SetInt64(asInt) + return z, nil + } + + return z.SetInterface(value) +} +{{- end}} + +{{- define "setElement element value elementType"}} +{{- if eq .elementType "fr.Element"}} SetElement(&{{.element}}, {{.value}}) +{{- else if eq .elementType "small_rational.SmallRational"}} {{.element}}.SetInterface({{.value}}) +{{- else}} + {{print "\"UNEXPECTED TYPE" .elementType "\""}} +{{- end}} +{{- end}} + +func SliceToElementSlice[T any](slice []T) ([]{{.ElementType}}, error) { + elementSlice := make([]{{.ElementType}}, len(slice)) + for i, v := range slice { + if _, err := {{setElement "elementSlice[i]" "v" .ElementType}}; err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []{{.ElementType}}, b []{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +func SliceSliceEquals(a [][]{{.ElementType}}, b [][]{{.ElementType}}) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func PolynomialSliceEquals(a []polynomial.Polynomial, b []polynomial.Polynomial) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if err := SliceEquals(a[i],b[i]); err != nil { + return fmt.Errorf("at index %d: %w", i, err) + } + } + return nil +} + +func ElementToInterface(x *{{.ElementType}}) interface{} { + text := x.Text(10) + if len(text) < 10 && !strings.Contains(text, "/") { + if i, err := strconv.Atoi(text); err != nil { + panic(err.Error()) + } else { + return i + } + } + return text +} + +func ElementSliceToInterfaceSlice(x interface{}) []interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([]interface{}, X.Len()) + for i := range res { + xI := X.Index(i).Interface().({{.ElementType}}) + res[i] = ElementToInterface(&xI) + } + return res +} + +func ElementSliceSliceToInterfaceSliceSlice(x interface{}) [][]interface{} { + if x == nil { + return nil + } + + X := reflect.ValueOf(x) + + res := make([][]interface{}, X.Len()) + for i := range res { + res[i] = ElementSliceToInterfaceSlice(X.Index(i).Interface()) + } + + return res +} diff --git a/internal/generator/test_vector_utils/utils.go b/internal/generator/test_vector_utils/utils.go new file mode 100644 index 000000000..47b05c646 --- /dev/null +++ b/internal/generator/test_vector_utils/utils.go @@ -0,0 +1,248 @@ +package test_vector_utils + +/* +var hashCache = make(map[string]HashMap) + +func GetHash(path string) (HashMap, error) { + path, err := filepath.Abs(path) + if err != nil { + return nil, err + } + if h, ok := hashCache[path]; ok { + return h, nil + } + var bytes []byte + if bytes, err = os.ReadFile(path); err == nil { + var asMap map[string]interface{} + if err = json.Unmarshal(bytes, &asMap); err != nil { + return nil, err + } + + res := make(HashMap, 0, len(asMap)) + + for k, v := range asMap { + var entry RationalTriplet + if _, err = entry.value.SetInterface(v); err != nil { + return nil, err + } + + key := strings.Split(k, ",") + + switch len(key) { + case 1: + entry.key2Present = false + case 2: + entry.key2Present = true + if _, err = entry.key2.SetInterface(key[1]); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("cannot parse %T as one or two field elements", v) + } + if _, err = entry.key1.SetInterface(key[0]); err != nil { + return nil, err + } + + res = append(res, &entry) + } + + res.sort() + + hashCache[path] = res + + return res, nil + + } else { + return nil, err + } +} + +func (m *HashMap) SaveUsedEntries(path string) error { + + var sb strings.Builder + sb.WriteRune('[') + + first := true + + for _, element := range *m { + if !element.used { + continue + } + if !first { + sb.WriteRune(',') + } + first = false + sb.WriteString("\n\t") + element.WriteKeyValue(&sb) + } + + if !first { + sb.WriteRune(',') + } + + sb.WriteString("\n]") + + return os.WriteFile(path, []byte(sb.String()), 0) +} + +type HashMap []*RationalTriplet + +type RationalTriplet struct { + key1 small_rational.SmallRational + key2 small_rational.SmallRational + key2Present bool + value small_rational.SmallRational + used bool +} + +func (t *RationalTriplet) WriteKeyValue(sb *strings.Builder) { + sb.WriteString("\t\"") + sb.WriteString(t.key1.String()) + if t.key2Present { + sb.WriteRune(',') + sb.WriteString(t.key2.String()) + } + sb.WriteString("\":") + if valueBytes, err := json.Marshal(ElementToInterface(&t.value)); err == nil { + sb.WriteString(string(valueBytes)) + } else { + panic(err.Error()) + } +} + +func (m *HashMap) sort() { + sort.Slice(*m, func(i, j int) bool { + return (*m)[i].CmpKey((*m)[j]) <= 0 + }) +} + +func (m *HashMap) find(toFind *RationalTriplet) small_rational.SmallRational { + i := sort.Search(len(*m), func(i int) bool { return (*m)[i].CmpKey(toFind) >= 0 }) + + if i < len(*m) && (*m)[i].CmpKey(toFind) == 0 { + (*m)[i].used = true + return (*m)[i].value + } + + // if not found, add it: + if _, err := toFind.value.SetInterface(rand.Int63n(11) - 5); err != nil { + panic(err.Error()) + } + toFind.used = true + *m = append(*m, toFind) + m.sort() //Inefficient, but it's okay. This is only run when a new test case is introduced + + return toFind.value +} + +func (t *RationalTriplet) CmpKey(o *RationalTriplet) int { + if cmp1 := t.key1.Cmp(&o.key1); cmp1 != 0 { + return cmp1 + } + + if t.key2Present { + if o.key2Present { + return t.key2.Cmp(&o.key2) + } + return 1 + } else { + if o.key2Present { + return -1 + } + return 0 + } +} + +type MapHashTranscript struct { + HashMap HashMap + stateValid bool + resultAvailable bool + state small_rational.SmallRational +} + +func (m *HashMap) Hash(x *small_rational.SmallRational, y *small_rational.SmallRational) small_rational.SmallRational { + + toFind := RationalTriplet{ + key1: *x, + key2Present: y != nil, + } + + if y != nil { + toFind.key2 = *y + } + + return m.find(&toFind) +} + +func (m *MapHashTranscript) Update(i ...interface{}) { + if len(i) > 0 { + for _, x := range i { + + var xElement small_rational.SmallRational + if _, err := xElement.SetInterface(x); err != nil { + panic(err.Error()) + } + if m.stateValid { + m.state = m.HashMap.Hash(&xElement, &m.state) + } else { + m.state = m.HashMap.Hash(&xElement, nil) + } + + m.stateValid = true + } + } else { //just hash the state itself + if !m.stateValid { + panic("nothing to hash") + } + m.state = m.HashMap.Hash(&m.state, nil) + } + m.resultAvailable = true +} + +func (m *MapHashTranscript) Next(i ...interface{}) small_rational.SmallRational { + + if len(i) > 0 || !m.resultAvailable { + m.Update(i...) + } + m.resultAvailable = false + return m.state +} + +func (m *MapHashTranscript) NextN(N int, i ...interface{}) []small_rational.SmallRational { + + if len(i) > 0 { + m.Update(i...) + } + + res := make([]small_rational.SmallRational, N) + + for n := range res { + res[n] = m.Next() + } + + return res +} + +func SliceToElementSlice(slice []interface{}) ([]small_rational.SmallRational, error) { + elementSlice := make([]small_rational.SmallRational, len(slice)) + for i, v := range slice { + if _, err := elementSlice[i].SetInterface(v); err != nil { + return nil, err + } + } + return elementSlice, nil +} + +func SliceEquals(a []small_rational.SmallRational, b []small_rational.SmallRational) error { + if len(a) != len(b) { + return fmt.Errorf("length mismatch %d≠%d", len(a), len(b)) + } + for i := range a { + if !a[i].Equal(&b[i]) { + return fmt.Errorf("at index %d: %s ≠ %s", i, a[i].String(), b[i].String()) + } + } + return nil +} + +*/ diff --git a/internal/generator/tower/asm/amd64/e2.go b/internal/generator/tower/asm/amd64/e2.go index e375508ab..5fb3ea969 100644 --- a/internal/generator/tower/asm/amd64/e2.go +++ b/internal/generator/tower/asm/amd64/e2.go @@ -19,8 +19,8 @@ import ( "github.com/consensys/bavard" ramd64 "github.com/consensys/bavard/amd64" - "github.com/consensys/gnark-crypto/internal/field" - "github.com/consensys/gnark-crypto/internal/field/asm/amd64" + "github.com/consensys/gnark-crypto/field/generator/asm/amd64" + field "github.com/consensys/gnark-crypto/field/generator/config" "github.com/consensys/gnark-crypto/internal/generator/config" ) diff --git a/internal/generator/tower/asm/amd64/e2_bn254.go b/internal/generator/tower/asm/amd64/e2_bn254.go index d56274902..821338109 100644 --- a/internal/generator/tower/asm/amd64/e2_bn254.go +++ b/internal/generator/tower/asm/amd64/e2_bn254.go @@ -21,7 +21,7 @@ import ( "strings" "github.com/consensys/bavard/amd64" - gamd64 "github.com/consensys/gnark-crypto/internal/field/asm/amd64" + gamd64 "github.com/consensys/gnark-crypto/field/generator/asm/amd64" ) func (fq2 *Fq2Amd64) generateMulByNonResidueE2BN254() { diff --git a/internal/generator/tower/generate.go b/internal/generator/tower/generate.go index 5ee50d644..ba6919473 100644 --- a/internal/generator/tower/generate.go +++ b/internal/generator/tower/generate.go @@ -2,7 +2,6 @@ package tower import ( "fmt" - "io" "os" "path/filepath" @@ -13,7 +12,7 @@ import ( // Generate generates a tower 2->6->12 over fp func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) error { - if conf.Equal(config.BW6_756) || conf.Equal(config.BW6_761) || conf.Equal(config.BW6_633) || conf.Equal(config.BLS24_315) || conf.Equal(config.BLS24_317) { + if conf.Equal(config.BW6_756) || conf.Equal(config.BW6_761) || conf.Equal(config.BW6_633) || conf.Equal(config.BLS24_315) || conf.Equal(config.BLS24_317) || conf.Equal(config.SECP256K1) { return nil } @@ -78,9 +77,6 @@ func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) er return err } - if conf.Equal(config.BN254) || conf.Equal(config.BLS12_381) { - _, _ = io.WriteString(f, "// +build !amd64_adx\n") - } Fq2Amd64 := amd64.NewFq2Amd64(f, conf.Fp, conf) if err := Fq2Amd64.Generate(true); err != nil { _ = f.Close() @@ -94,19 +90,7 @@ func Generate(conf config.Curve, baseDir string, bgen *bavard.BatchGenerator) er { // fq2 assembly fName := filepath.Join(baseDir, "e2_adx_amd64.s") - f, err := os.Create(fName) - if err != nil { - return err - } - - _, _ = io.WriteString(f, "// +build amd64_adx\n") - Fq2Amd64 := amd64.NewFq2Amd64(f, conf.Fp, conf) - if err := Fq2Amd64.Generate(false); err != nil { - _ = f.Close() - return err - } - _ = f.Close() - + os.Remove(fName) } } diff --git a/internal/generator/tower/template/fq12over6over2/fq12.go.tmpl b/internal/generator/tower/template/fq12over6over2/fq12.go.tmpl index c37fa9796..a2853c37f 100644 --- a/internal/generator/tower/template/fq12over6over2/fq12.go.tmpl +++ b/internal/generator/tower/template/fq12over6over2/fq12.go.tmpl @@ -1,6 +1,5 @@ import ( "math/big" - "encoding/binary" "errors" "sync" "github.com/consensys/gnark-crypto/ecc" @@ -50,20 +49,6 @@ func (z *E12) SetOne() *E12 { return z } -// ToMont converts to Mont form -func (z *E12) ToMont() *E12 { - z.C0.ToMont() - z.C1.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E12) FromMont() *E12 { - z.C0.FromMont() - z.C1.FromMont() - return z -} - // Add set z=x+y in E12 and return z func (z *E12) Add(x, y *E12) *E12 { z.C0.Add(&x.C0, &y.C0) @@ -101,6 +86,10 @@ func (z *E12) IsZero() bool { return z.C0.IsZero() && z.C1.IsZero() } +func (z *E12) IsOne() bool { + return z.C0.IsOne() && z.C1.IsZero() +} + // Mul set z=x*y in E12 and return z func (z *E12) Mul(x, y *E12) *E12 { var a, b, c E6 @@ -585,8 +574,8 @@ func (z *E12) ExpGLV(x E12, k *big.Int) *E12 { table[14].Mul(&table[11], &table[2]) // bounds on the lattice base vectors guarantee that s1, s2 are len(r)/2 bits long max - s1.SetBigInt(&s[0]).FromMont() - s2.SetBigInt(&s[1]).FromMont() + s1 = s1.SetBigInt(&s[0]).Bits() + s2 = s2.SetBigInt(&s[1]).Bits() // loop starts from len(s1)/2 due to the bounds for i := len(s1) / 2; i >= 0; i-- { @@ -638,46 +627,43 @@ func (z *E12) Unmarshal(buf []byte) error { // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) Bytes() (r [SizeOfGT]byte) { - _z := *z - _z.FromMont() - {{- $offset := mul $sizeOfFp 11}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B0.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B0.A0"}} {{- $offset := mul $sizeOfFp 10}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B0.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B0.A1"}} {{- $offset := mul $sizeOfFp 9}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B1.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B1.A0"}} {{- $offset := mul $sizeOfFp 8}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B1.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B1.A1"}} {{- $offset := mul $sizeOfFp 7}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B2.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B2.A0"}} {{- $offset := mul $sizeOfFp 6}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C0.B2.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C0.B2.A1"}} {{- $offset := mul $sizeOfFp 5}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B0.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B0.A0"}} {{- $offset := mul $sizeOfFp 4}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B0.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B0.A1"}} {{- $offset := mul $sizeOfFp 3}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B1.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B1.A0"}} {{- $offset := mul $sizeOfFp 2}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B1.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B1.A1"}} {{- $offset := mul $sizeOfFp 1}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B2.A0"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B2.A0"}} {{- $offset := mul $sizeOfFp 0}} - {{- template "putFp" dict "all" . "OffSet" $offset "From" "_z.C1.B2.A1"}} + {{- template "putFp" dict "all" . "OffSet" $offset "From" "z.C1.B2.A1"}} return } @@ -686,7 +672,7 @@ func (z *E12) Bytes() (r [SizeOfGT]byte) { // SetBytes interprets e as the bytes of a big-endian GT // sets z to that value (in Montgomery form), and returns z. // size(e) == {{ $sizeOfFp }} * 12 -// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... +// z.C1.B2.A1 | z.C1.B2.A0 | z.C1.B1.A1 | ... func (z *E12) SetBytes(e []byte) error { if len(e) != SizeOfGT { return errors.New("invalid buffer size") @@ -761,20 +747,15 @@ func (z *E12) IsInSubGroup() bool { return a.Equal(&b) } -{{define "putFp"}} - {{- range $i := reverse .all.Curve.Fp.NbWordsIndexesFull}} - {{- $j := mul $i 8}} - {{- $j := add $j $.OffSet}} - {{- $k := sub $.all.Curve.Fp.NbWords 1}} - {{- $k := sub $k $i}} - {{- $jj := add $j 8}} - binary.BigEndian.PutUint64(r[{{$j}}:{{$jj}}], {{$.From}}[{{$k}}]) - {{- end}} -{{end}} - -{{define "readFp"}} - {{$.To}}.SetBytes(e[{{$.OffSet}}:{{$.OffSet}} + fp.Bytes]) -{{end}} +{{- define "putFp"}} + fp.BigEndian.PutElement((*[fp.Bytes]byte)( r[{{$.OffSet}}:{{$.OffSet}} + fp.Bytes]), {{$.From}}) +{{- end}} + +{{- define "readFp"}} + if err := {{$.To}}.SetBytesCanonical(e[{{$.OffSet}}:{{$.OffSet}} + fp.Bytes]); err != nil { + return err + } +{{- end}} // CompressTorus GT/E12 element to half its size // z must be in the cyclotomic subgroup diff --git a/internal/generator/tower/template/fq12over6over2/fq2.go.tmpl b/internal/generator/tower/template/fq12over6over2/fq2.go.tmpl index 471d6af8f..c79042021 100644 --- a/internal/generator/tower/template/fq12over6over2/fq2.go.tmpl +++ b/internal/generator/tower/template/fq12over6over2/fq2.go.tmpl @@ -14,6 +14,16 @@ func (z *E2) Equal(x *E2) bool { return z.A0.Equal(&x.A0) && z.A1.Equal(&x.A1) } +// Bits +// TODO @gbotrel fixme this shouldn't return a E2 +func (z *E2) Bits() E2 { + r := E2 {} + r.A0 = z.A0.Bits() + r.A1 = z.A1.Bits() + return r +} + + // Cmp compares (lexicographic order) z and x and returns: // // -1 if z < x @@ -81,6 +91,10 @@ func (z *E2) IsZero() bool { return z.A0.IsZero() && z.A1.IsZero() } +func (z *E2) IsOne() bool { + return z.A0.IsOne() && z.A1.IsZero() +} + // Add adds two elements of E2 func (z *E2) Add(x, y *E2) *E2 { addE2(z, x, y) @@ -112,20 +126,6 @@ func (z *E2) String() string { return z.A0.String() + "+" + z.A1.String() + "*u" } -// ToMont converts to mont form -func (z *E2) ToMont() *E2 { - z.A0.ToMont() - z.A1.ToMont() - return z -} - -// FromMont converts from mont form -func (z *E2) FromMont() *E2 { - z.A0.FromMont() - z.A1.FromMont() - return z -} - // MulByElement multiplies an element in E2 by an element in fp func (z *E2) MulByElement(x *E2, y *fp.Element) *E2 { var yCopy fp.Element diff --git a/internal/generator/tower/template/fq12over6over2/fq6.go.tmpl b/internal/generator/tower/template/fq12over6over2/fq6.go.tmpl index 71337e9f6..c48ffe575 100644 --- a/internal/generator/tower/template/fq12over6over2/fq6.go.tmpl +++ b/internal/generator/tower/template/fq12over6over2/fq6.go.tmpl @@ -45,25 +45,13 @@ func (z *E6) SetRandom() (*E6, error) { return z, nil } -// IsZero returns true if the two elements are equal, fasle otherwise +// IsZero returns true if the two elements are equal, false otherwise func (z *E6) IsZero() bool { return z.B0.IsZero() && z.B1.IsZero() && z.B2.IsZero() } -// ToMont converts to Mont form -func (z *E6) ToMont() *E6 { - z.B0.ToMont() - z.B1.ToMont() - z.B2.ToMont() - return z -} - -// FromMont converts from Mont form -func (z *E6) FromMont() *E6 { - z.B0.FromMont() - z.B1.FromMont() - z.B2.FromMont() - return z +func (z *E6) IsOne() bool { + return z.B0.IsOne() && z.B1.IsZero() && z.B2.IsZero() } // Add adds two elements of E6 diff --git a/utils/decompose.go b/utils/decompose.go new file mode 100644 index 000000000..dc847b696 --- /dev/null +++ b/utils/decompose.go @@ -0,0 +1,38 @@ +// Copyright 2020 ConsenSys Software Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.s + +package utils + +import "math/big" + +// Decompose interpret rawBytes as a bigInt x in big endian, +// and returns the digits of x (from LSB to MSB) when x is written +// in basis modulo. +func Decompose(rawBytes []byte, modulo *big.Int) (decomposed []byte) { + raw := big.NewInt(0).SetBytes(rawBytes) + + var chunk [32]byte + decomposed = make([]byte, 0, len(rawBytes)) + for raw.Cmp(modulo) >= 0 { + mod := big.NewInt(0).Mod(raw, modulo) + mod.FillBytes(chunk[:]) + decomposed = append(decomposed, chunk[:]...) + + raw.Div(raw, modulo) + } + + raw.FillBytes(chunk[:]) + decomposed = append(decomposed, chunk[:]...) + return decomposed +}