diff --git a/.gitignore b/.gitignore index c3e412fb64..bb95194a5f 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ field/generator/addchain/** # compiled sage -> python code *.sage.py + +.DS_Store +.idea diff --git a/ecc/bls12-377/fp/element.go b/ecc/bls12-377/fp/element.go index 7862068837..c488b0e4ae 100644 --- a/ecc/bls12-377/fp/element.go +++ b/ecc/bls12-377/fp/element.go @@ -63,15 +63,25 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 9586122913090633729 +const qElementWord1 uint64 = 1660523435060625408 +const qElementWord2 uint64 = 2230234197602682880 +const qElementWord3 uint64 = 1883307231910630287 +const qElementWord4 uint64 = 14284016967150029115 +const qElementWord5 uint64 = 121098312706494698 + var qElement = Element{ - 9586122913090633729, - 1660523435060625408, - 2230234197602682880, - 1883307231910630287, - 14284016967150029115, - 121098312706494698, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 9586122913090633727 + // rSquare var rSquare = Element{ 13224372171368877346, @@ -197,7 +207,7 @@ func (z *Element) IsZero() bool { return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[5] | z[4] | z[3] | z[2] | z[1]) == 0 } @@ -281,7 +291,7 @@ func (z *Element) SetRandom() (*Element, error) { z[5] = binary.BigEndian.Uint64(bytes[40:48]) z[5] %= 121098312706494698 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -485,7 +495,90 @@ func _mulGeneric(z, x, y *Element) { z[5], z[4] = madd3(m, 121098312706494698, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 9586122913090633729, 0) + z[1], b = bits.Sub64(z[1], 1660523435060625408, b) + z[2], b = bits.Sub64(z[2], 2230234197602682880, b) + z[3], b = bits.Sub64(z[3], 1883307231910630287, b) + z[4], b = bits.Sub64(z[4], 14284016967150029115, b) + z[5], _ = bits.Sub64(z[5], 121098312706494698, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [6]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 1660523435060625408, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 2230234197602682880, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 1883307231910630287, c2, c0) + c1, c0 = madd1(y, x[4], c1) + c2, t[3] = madd2(m, 14284016967150029115, c2, c0) + c1, c0 = madd1(y, x[5], c1) + t[5], t[4] = madd3(m, 121098312706494698, c0, c2, c1) + } + { + // round 1 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 2 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 3 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 4 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 5 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, z[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, z[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, z[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, z[3] = madd2(m, 14284016967150029115, c2, t[4]) + z[5], z[4] = madd2(m, 121098312706494698, t[5], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -568,7 +661,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -591,7 +684,7 @@ func _addGeneric(z, x, y *Element) { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -614,7 +707,7 @@ func _doubleGeneric(z, x *Element) { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -662,7 +755,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -778,7 +871,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -912,7 +1005,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1100,181 +1193,496 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 16386826051656692015 +const inversionCorrectionFactorWord1 = 8373462824848618879 +const inversionCorrectionFactorWord2 = 7553521018781888459 +const inversionCorrectionFactorWord3 = 595240760696852504 +const inversionCorrectionFactorWord4 = 16794241053652767540 +const inversionCorrectionFactorWord5 = 43911691917702151 + +const invIterationsN = 26 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 9586122913090633729, - 1660523435060625408, - 2230234197602682880, - 1883307231910630287, - 14284016967150029115, - 121098312706494698, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | ((a[5]) << approxHighBitsN) + a[5] = (a[5] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | ((b[5]) << approxHighBitsN) + b[5] = (b[5] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 13224372171368877346, - 227991066186625457, - 2496666625421784173, - 13825906835078366124, - 9475172226622360569, - 30958721782860680, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] = v[4]>>1 | v[5]<<63 - v[5] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 9586122913090633729, 0) - s[1], carry = bits.Add64(s[1], 1660523435060625408, carry) - s[2], carry = bits.Add64(s[2], 2230234197602682880, carry) - s[3], carry = bits.Add64(s[3], 1883307231910630287, carry) - s[4], carry = bits.Add64(s[4], 14284016967150029115, carry) - s[5], _ = bits.Add64(s[5], 121098312706494698, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) + C, t[5] = madd2(m, qElementWord5, x[5], C) + + // the high word of m * qElement[5] is at most 62 bits + // x[5] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[6] = xHi + C + // xHi and C are 63 bits, therefore no overflow + + { + const i = 1 + m = t[i] * qInvNegLsw - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] = s[4]>>1 | s[5]<<63 - s[5] >>= 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 4 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 5 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + C, z[3] = madd2(m, qElementWord4, t[i+4], C) + z[5], z[4] = madd2(m, qElementWord5, t[i+5], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 9586122913090633729, 0) + z[1], b = bits.Sub64(z[1], 1660523435060625408, b) + z[2], b = bits.Sub64(z[2], 2230234197602682880, b) + z[3], b = bits.Sub64(z[3], 1883307231910630287, b) + z[4], b = bits.Sub64(z[4], 14284016967150029115, b) + z[5], _ = bits.Sub64(z[5], 121098312706494698, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + z[5], b = bits.Sub64(z[5], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[5] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], b = bits.Add64(z[4], qElementWord4, b) + z[5], _ = bits.Add64(neg1, qElementWord5, b) } - for u[0]&1 == 0 { + } +} - // u = u >> 1 +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] = u[4]>>1 | u[5]<<63 - u[5] >>= 1 +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - if r[0]&1 == 1 { + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + z[5], b = bits.Sub64(0, x[5], b) + xHi, _ = bits.Sub64(0, xHi, b) - // r = r + q - r[0], carry = bits.Add64(r[0], 9586122913090633729, 0) - r[1], carry = bits.Add64(r[1], 1660523435060625408, carry) - r[2], carry = bits.Add64(r[2], 2230234197602682880, carry) - r[3], carry = bits.Add64(r[3], 1883307231910630287, carry) - r[4], carry = bits.Add64(r[4], 14284016967150029115, carry) - r[5], _ = bits.Add64(r[5], 121098312706494698, carry) + return xHi +} - } +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) + c, z[5] = madd1(x[5], w, c) + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - // r = r >> 1 + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] = r[4]>>1 | r[5]<<63 - r[5] >>= 1 + h[0], s[0] = bits.Mul64(x[0], w) + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - // v >= u - bigger = !(v[5] < u[5] || (v[5] == u[5] && (v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], borrow = bits.Sub64(v[4], u[4], borrow) - v[5], _ = bits.Sub64(v[5], u[5], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) - s[5], borrow = bits.Sub64(s[5], r[5], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 9586122913090633729, 0) - s[1], carry = bits.Add64(s[1], 1660523435060625408, carry) - s[2], carry = bits.Add64(s[2], 2230234197602682880, carry) - s[3], carry = bits.Add64(s[3], 1883307231910630287, carry) - s[4], carry = bits.Add64(s[4], 14284016967150029115, carry) - s[5], _ = bits.Add64(s[5], 121098312706494698, carry) + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 - } - } else { + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], borrow = bits.Sub64(u[4], v[4], borrow) - u[5], _ = bits.Sub64(u[5], v[5], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - r[5], borrow = bits.Sub64(r[5], s[5], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 9586122913090633729, 0) - r[1], carry = bits.Add64(r[1], 1660523435060625408, carry) - r[2], carry = bits.Add64(r[2], 2230234197602682880, carry) - r[3], carry = bits.Add64(r[3], 1883307231910630287, carry) - r[4], carry = bits.Add64(r[4], 14284016967150029115, carry) - r[5], _ = bits.Add64(r[5], 121098312706494698, carry) + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 - } + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[5]|u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[5]|v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 5 - 1 + + h[curI], s[curI] = bits.Mul64(x[5], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 6 % 2 + const prevI = 1 - curI + const iMinusOne = 5 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + z[5], carry = bits.Add64(z[5], yTimes[5], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls12-377/fp/element_test.go b/ecc/bls12-377/fp/element_test.go index fc5cd40a7e..dbcd6a36bf 100644 --- a/ecc/bls12-377/fp/element_test.go +++ b/ecc/bls12-377/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -275,7 +276,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1990,3 +1991,504 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bls12-377/fr/element.go b/ecc/bls12-377/fr/element.go index 641041b072..bb791d7673 100644 --- a/ecc/bls12-377/fr/element.go +++ b/ecc/bls12-377/fr/element.go @@ -63,13 +63,21 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 725501752471715841 +const qElementWord1 uint64 = 6461107452199829505 +const qElementWord2 uint64 = 6968279316240510977 +const qElementWord3 uint64 = 1345280370688173398 + var qElement = Element{ - 725501752471715841, - 6461107452199829505, - 6968279316240510977, - 1345280370688173398, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 725501752471715839 + // rSquare var rSquare = Element{ 2726216793283724667, @@ -187,7 +195,7 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[3] | z[2] | z[1]) == 0 } @@ -257,7 +265,7 @@ func (z *Element) SetRandom() (*Element, error) { z[3] = binary.BigEndian.Uint64(bytes[24:32]) z[3] %= 1345280370688173398 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -405,7 +413,58 @@ func _mulGeneric(z, x, y *Element) { z[3], z[2] = madd3(m, 1345280370688173398, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 725501752471715841, 0) + z[1], b = bits.Sub64(z[1], 6461107452199829505, b) + z[2], b = bits.Sub64(z[2], 6968279316240510977, b) + z[3], _ = bits.Sub64(z[3], 1345280370688173398, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [4]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 725501752471715839 + c2 := madd0(m, 725501752471715841, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 6461107452199829505, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 6968279316240510977, c2, c0) + c1, c0 = madd1(y, x[3], c1) + t[3], t[2] = madd3(m, 1345280370688173398, c0, c2, c1) + } + { + // round 1 + m := t[0] * 725501752471715839 + c2 := madd0(m, 725501752471715841, t[0]) + c2, t[0] = madd2(m, 6461107452199829505, c2, t[1]) + c2, t[1] = madd2(m, 6968279316240510977, c2, t[2]) + t[3], t[2] = madd2(m, 1345280370688173398, t[3], c2) + } + { + // round 2 + m := t[0] * 725501752471715839 + c2 := madd0(m, 725501752471715841, t[0]) + c2, t[0] = madd2(m, 6461107452199829505, c2, t[1]) + c2, t[1] = madd2(m, 6968279316240510977, c2, t[2]) + t[3], t[2] = madd2(m, 1345280370688173398, t[3], c2) + } + { + // round 3 + m := t[0] * 725501752471715839 + c2 := madd0(m, 725501752471715841, t[0]) + c2, z[0] = madd2(m, 6461107452199829505, c2, t[1]) + c2, z[1] = madd2(m, 6968279316240510977, c2, t[2]) + z[3], z[2] = madd2(m, 1345280370688173398, t[3], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -456,7 +515,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -475,7 +534,7 @@ func _addGeneric(z, x, y *Element) { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -494,7 +553,7 @@ func _doubleGeneric(z, x *Element) { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -534,7 +593,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { var b uint64 @@ -642,7 +701,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -772,7 +831,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -958,153 +1017,418 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 10532255328610800637 +const inversionCorrectionFactorWord1 = 12368705355905441295 +const inversionCorrectionFactorWord2 = 4360254653162267807 +const inversionCorrectionFactorWord3 = 835375988631323795 + +const invIterationsN = 18 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 725501752471715841, - 6461107452199829505, - 6968279316240510977, - 1345280370688173398, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 2726216793283724667, - 14712177743343147295, - 12091039717619697043, - 81024008013859129, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 725501752471715841, 0) - s[1], carry = bits.Add64(s[1], 6461107452199829505, carry) - s[2], carry = bits.Add64(s[2], 6968279316240510977, carry) - s[3], _ = bits.Add64(s[3], 1345280370688173398, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 725501752471715841, 0) - r[1], carry = bits.Add64(r[1], 6461107452199829505, carry) - r[2], carry = bits.Add64(r[2], 6968279316240510977, carry) - r[3], _ = bits.Add64(r[3], 1345280370688173398, carry) + // the high word of m * qElement[3] is at most 62 bits + // x[3] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[4] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + z[3], z[2] = madd2(m, qElementWord3, t[i+3], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 1345280370688173398 || (z[3] == 1345280370688173398 && (z[2] < 6968279316240510977 || (z[2] == 6968279316240510977 && (z[1] < 6461107452199829505 || (z[1] == 6461107452199829505 && (z[0] < 725501752471715841))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 725501752471715841, 0) + z[1], b = bits.Sub64(z[1], 6461107452199829505, b) + z[2], b = bits.Sub64(z[2], 6968279316240510977, b) + z[3], _ = bits.Sub64(z[3], 1345280370688173398, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], _ = bits.Add64(neg1, qElementWord3, b) } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} + +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) - if bigger { + return xHi +} - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - if borrow == 1 { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) - // s = s + q - s[0], carry = bits.Add64(s[0], 725501752471715841, 0) - s[1], carry = bits.Add64(s[1], 6461107452199829505, carry) - s[2], carry = bits.Add64(s[2], 6968279316240510977, carry) - s[3], _ = bits.Add64(s[3], 1345280370688173398, carry) + if y < 0 { + c = z.neg(z, c) + } - } - } else { + return c +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - if borrow == 1 { + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - // r = r + q - r[0], carry = bits.Add64(r[0], 725501752471715841, 0) - r[1], carry = bits.Add64(r[1], 6461107452199829505, carry) - r[2], carry = bits.Add64(r[2], 6968279316240510977, carry) - r[3], _ = bits.Add64(r[3], 1345280370688173398, carry) + h[0], s[0] = bits.Mul64(x[0], w) - } + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 3 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls12-377/fr/element_test.go b/ecc/bls12-377/fr/element_test.go index 5e7e532473..4debe45cb5 100644 --- a/ecc/bls12-377/fr/element_test.go +++ b/ecc/bls12-377/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -271,7 +272,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1962,3 +1963,500 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bls12-381/fp/element.go b/ecc/bls12-381/fp/element.go index 1ec8b3b506..edf8e18411 100644 --- a/ecc/bls12-381/fp/element.go +++ b/ecc/bls12-381/fp/element.go @@ -63,15 +63,25 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 13402431016077863595 +const qElementWord1 uint64 = 2210141511517208575 +const qElementWord2 uint64 = 7435674573564081700 +const qElementWord3 uint64 = 7239337960414712511 +const qElementWord4 uint64 = 5412103778470702295 +const qElementWord5 uint64 = 1873798617647539866 + var qElement = Element{ - 13402431016077863595, - 2210141511517208575, - 7435674573564081700, - 7239337960414712511, - 5412103778470702295, - 1873798617647539866, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 9940570264628428797 + // rSquare var rSquare = Element{ 17644856173732828998, @@ -197,7 +207,7 @@ func (z *Element) IsZero() bool { return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[5] | z[4] | z[3] | z[2] | z[1]) == 0 } @@ -281,7 +291,7 @@ func (z *Element) SetRandom() (*Element, error) { z[5] = binary.BigEndian.Uint64(bytes[40:48]) z[5] %= 1873798617647539866 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -485,7 +495,90 @@ func _mulGeneric(z, x, y *Element) { z[5], z[4] = madd3(m, 1873798617647539866, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 13402431016077863595, 0) + z[1], b = bits.Sub64(z[1], 2210141511517208575, b) + z[2], b = bits.Sub64(z[2], 7435674573564081700, b) + z[3], b = bits.Sub64(z[3], 7239337960414712511, b) + z[4], b = bits.Sub64(z[4], 5412103778470702295, b) + z[5], _ = bits.Sub64(z[5], 1873798617647539866, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [6]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 2210141511517208575, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 7435674573564081700, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 7239337960414712511, c2, c0) + c1, c0 = madd1(y, x[4], c1) + c2, t[3] = madd2(m, 5412103778470702295, c2, c0) + c1, c0 = madd1(y, x[5], c1) + t[5], t[4] = madd3(m, 1873798617647539866, c0, c2, c1) + } + { + // round 1 + m := t[0] * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, t[0]) + c2, t[0] = madd2(m, 2210141511517208575, c2, t[1]) + c2, t[1] = madd2(m, 7435674573564081700, c2, t[2]) + c2, t[2] = madd2(m, 7239337960414712511, c2, t[3]) + c2, t[3] = madd2(m, 5412103778470702295, c2, t[4]) + t[5], t[4] = madd2(m, 1873798617647539866, t[5], c2) + } + { + // round 2 + m := t[0] * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, t[0]) + c2, t[0] = madd2(m, 2210141511517208575, c2, t[1]) + c2, t[1] = madd2(m, 7435674573564081700, c2, t[2]) + c2, t[2] = madd2(m, 7239337960414712511, c2, t[3]) + c2, t[3] = madd2(m, 5412103778470702295, c2, t[4]) + t[5], t[4] = madd2(m, 1873798617647539866, t[5], c2) + } + { + // round 3 + m := t[0] * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, t[0]) + c2, t[0] = madd2(m, 2210141511517208575, c2, t[1]) + c2, t[1] = madd2(m, 7435674573564081700, c2, t[2]) + c2, t[2] = madd2(m, 7239337960414712511, c2, t[3]) + c2, t[3] = madd2(m, 5412103778470702295, c2, t[4]) + t[5], t[4] = madd2(m, 1873798617647539866, t[5], c2) + } + { + // round 4 + m := t[0] * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, t[0]) + c2, t[0] = madd2(m, 2210141511517208575, c2, t[1]) + c2, t[1] = madd2(m, 7435674573564081700, c2, t[2]) + c2, t[2] = madd2(m, 7239337960414712511, c2, t[3]) + c2, t[3] = madd2(m, 5412103778470702295, c2, t[4]) + t[5], t[4] = madd2(m, 1873798617647539866, t[5], c2) + } + { + // round 5 + m := t[0] * 9940570264628428797 + c2 := madd0(m, 13402431016077863595, t[0]) + c2, z[0] = madd2(m, 2210141511517208575, c2, t[1]) + c2, z[1] = madd2(m, 7435674573564081700, c2, t[2]) + c2, z[2] = madd2(m, 7239337960414712511, c2, t[3]) + c2, z[3] = madd2(m, 5412103778470702295, c2, t[4]) + z[5], z[4] = madd2(m, 1873798617647539866, t[5], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -568,7 +661,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -591,7 +684,7 @@ func _addGeneric(z, x, y *Element) { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -614,7 +707,7 @@ func _doubleGeneric(z, x *Element) { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -662,7 +755,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { var b uint64 @@ -778,7 +871,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -912,7 +1005,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1046,181 +1139,496 @@ func (z *Element) Sqrt(x *Element) *Element { return nil } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 8737414717120368535 +const inversionCorrectionFactorWord1 = 10094300570241649429 +const inversionCorrectionFactorWord2 = 6339946188669102923 +const inversionCorrectionFactorWord3 = 10492640117780001228 +const inversionCorrectionFactorWord4 = 12201317704601795701 +const inversionCorrectionFactorWord5 = 1158882751927031822 + +const invIterationsN = 26 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 13402431016077863595, - 2210141511517208575, - 7435674573564081700, - 7239337960414712511, - 5412103778470702295, - 1873798617647539866, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | ((a[5]) << approxHighBitsN) + a[5] = (a[5] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | ((b[5]) << approxHighBitsN) + b[5] = (b[5] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 17644856173732828998, - 754043588434789617, - 10224657059481499349, - 7488229067341005760, - 11130996698012816685, - 1267921511277847466, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] = v[4]>>1 | v[5]<<63 - v[5] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 13402431016077863595, 0) - s[1], carry = bits.Add64(s[1], 2210141511517208575, carry) - s[2], carry = bits.Add64(s[2], 7435674573564081700, carry) - s[3], carry = bits.Add64(s[3], 7239337960414712511, carry) - s[4], carry = bits.Add64(s[4], 5412103778470702295, carry) - s[5], _ = bits.Add64(s[5], 1873798617647539866, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) + C, t[5] = madd2(m, qElementWord5, x[5], C) + + // the high word of m * qElement[5] is at most 62 bits + // x[5] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[6] = xHi + C + // xHi and C are 63 bits, therefore no overflow + + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] = s[4]>>1 | s[5]<<63 - s[5] >>= 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + t[i+Limbs] += C + } + { + const i = 4 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 5 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + C, z[3] = madd2(m, qElementWord4, t[i+4], C) + z[5], z[4] = madd2(m, qElementWord5, t[i+5], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 1873798617647539866 || (z[5] == 1873798617647539866 && (z[4] < 5412103778470702295 || (z[4] == 5412103778470702295 && (z[3] < 7239337960414712511 || (z[3] == 7239337960414712511 && (z[2] < 7435674573564081700 || (z[2] == 7435674573564081700 && (z[1] < 2210141511517208575 || (z[1] == 2210141511517208575 && (z[0] < 13402431016077863595))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 13402431016077863595, 0) + z[1], b = bits.Sub64(z[1], 2210141511517208575, b) + z[2], b = bits.Sub64(z[2], 7435674573564081700, b) + z[3], b = bits.Sub64(z[3], 7239337960414712511, b) + z[4], b = bits.Sub64(z[4], 5412103778470702295, b) + z[5], _ = bits.Sub64(z[5], 1873798617647539866, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + z[5], b = bits.Sub64(z[5], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[5] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], b = bits.Add64(z[4], qElementWord4, b) + z[5], _ = bits.Add64(neg1, qElementWord5, b) } - for u[0]&1 == 0 { + } +} - // u = u >> 1 +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] = u[4]>>1 | u[5]<<63 - u[5] >>= 1 +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - if r[0]&1 == 1 { + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + z[5], b = bits.Sub64(0, x[5], b) + xHi, _ = bits.Sub64(0, xHi, b) - // r = r + q - r[0], carry = bits.Add64(r[0], 13402431016077863595, 0) - r[1], carry = bits.Add64(r[1], 2210141511517208575, carry) - r[2], carry = bits.Add64(r[2], 7435674573564081700, carry) - r[3], carry = bits.Add64(r[3], 7239337960414712511, carry) - r[4], carry = bits.Add64(r[4], 5412103778470702295, carry) - r[5], _ = bits.Add64(r[5], 1873798617647539866, carry) + return xHi +} - } +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) + c, z[5] = madd1(x[5], w, c) + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - // r = r >> 1 + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] = r[4]>>1 | r[5]<<63 - r[5] >>= 1 + h[0], s[0] = bits.Mul64(x[0], w) + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - // v >= u - bigger = !(v[5] < u[5] || (v[5] == u[5] && (v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], borrow = bits.Sub64(v[4], u[4], borrow) - v[5], _ = bits.Sub64(v[5], u[5], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) - s[5], borrow = bits.Sub64(s[5], r[5], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 13402431016077863595, 0) - s[1], carry = bits.Add64(s[1], 2210141511517208575, carry) - s[2], carry = bits.Add64(s[2], 7435674573564081700, carry) - s[3], carry = bits.Add64(s[3], 7239337960414712511, carry) - s[4], carry = bits.Add64(s[4], 5412103778470702295, carry) - s[5], _ = bits.Add64(s[5], 1873798617647539866, carry) + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 - } - } else { + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], borrow = bits.Sub64(u[4], v[4], borrow) - u[5], _ = bits.Sub64(u[5], v[5], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - r[5], borrow = bits.Sub64(r[5], s[5], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 13402431016077863595, 0) - r[1], carry = bits.Add64(r[1], 2210141511517208575, carry) - r[2], carry = bits.Add64(r[2], 7435674573564081700, carry) - r[3], carry = bits.Add64(r[3], 7239337960414712511, carry) - r[4], carry = bits.Add64(r[4], 5412103778470702295, carry) - r[5], _ = bits.Add64(r[5], 1873798617647539866, carry) + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 - } + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[5]|u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[5]|v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 5 - 1 + + h[curI], s[curI] = bits.Mul64(x[5], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 6 % 2 + const prevI = 1 - curI + const iMinusOne = 5 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + z[5], carry = bits.Add64(z[5], yTimes[5], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls12-381/fp/element_test.go b/ecc/bls12-381/fp/element_test.go index dffe1800b5..b87e1f9303 100644 --- a/ecc/bls12-381/fp/element_test.go +++ b/ecc/bls12-381/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -275,7 +276,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1990,3 +1991,504 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bls12-381/fr/element.go b/ecc/bls12-381/fr/element.go index f02768b82d..4e5f901d13 100644 --- a/ecc/bls12-381/fr/element.go +++ b/ecc/bls12-381/fr/element.go @@ -63,13 +63,21 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 18446744069414584321 +const qElementWord1 uint64 = 6034159408538082302 +const qElementWord2 uint64 = 3691218898639771653 +const qElementWord3 uint64 = 8353516859464449352 + var qElement = Element{ - 18446744069414584321, - 6034159408538082302, - 3691218898639771653, - 8353516859464449352, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 18446744069414584319 + // rSquare var rSquare = Element{ 14526898881837571181, @@ -187,7 +195,7 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[3] | z[2] | z[1]) == 0 } @@ -257,7 +265,7 @@ func (z *Element) SetRandom() (*Element, error) { z[3] = binary.BigEndian.Uint64(bytes[24:32]) z[3] %= 8353516859464449352 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -405,7 +413,58 @@ func _mulGeneric(z, x, y *Element) { z[3], z[2] = madd3(m, 8353516859464449352, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) + z[1], b = bits.Sub64(z[1], 6034159408538082302, b) + z[2], b = bits.Sub64(z[2], 3691218898639771653, b) + z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [4]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 18446744069414584319 + c2 := madd0(m, 18446744069414584321, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 6034159408538082302, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 3691218898639771653, c2, c0) + c1, c0 = madd1(y, x[3], c1) + t[3], t[2] = madd3(m, 8353516859464449352, c0, c2, c1) + } + { + // round 1 + m := t[0] * 18446744069414584319 + c2 := madd0(m, 18446744069414584321, t[0]) + c2, t[0] = madd2(m, 6034159408538082302, c2, t[1]) + c2, t[1] = madd2(m, 3691218898639771653, c2, t[2]) + t[3], t[2] = madd2(m, 8353516859464449352, t[3], c2) + } + { + // round 2 + m := t[0] * 18446744069414584319 + c2 := madd0(m, 18446744069414584321, t[0]) + c2, t[0] = madd2(m, 6034159408538082302, c2, t[1]) + c2, t[1] = madd2(m, 3691218898639771653, c2, t[2]) + t[3], t[2] = madd2(m, 8353516859464449352, t[3], c2) + } + { + // round 3 + m := t[0] * 18446744069414584319 + c2 := madd0(m, 18446744069414584321, t[0]) + c2, z[0] = madd2(m, 6034159408538082302, c2, t[1]) + c2, z[1] = madd2(m, 3691218898639771653, c2, t[2]) + z[3], z[2] = madd2(m, 8353516859464449352, t[3], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -456,7 +515,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -475,7 +534,7 @@ func _addGeneric(z, x, y *Element) { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -494,7 +553,7 @@ func _doubleGeneric(z, x *Element) { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -534,7 +593,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { var b uint64 @@ -642,7 +701,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -772,7 +831,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -958,153 +1017,418 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 10120633560485349752 +const inversionCorrectionFactorWord1 = 6708885176490223342 +const inversionCorrectionFactorWord2 = 15589610060228208133 +const inversionCorrectionFactorWord3 = 1857276366933877101 + +const invIterationsN = 18 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 18446744069414584321, - 6034159408538082302, - 3691218898639771653, - 8353516859464449352, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 14526898881837571181, - 3129137299524312099, - 419701826671360399, - 524908885293268753, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 18446744069414584321, 0) - s[1], carry = bits.Add64(s[1], 6034159408538082302, carry) - s[2], carry = bits.Add64(s[2], 3691218898639771653, carry) - s[3], _ = bits.Add64(s[3], 8353516859464449352, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 18446744069414584321, 0) - r[1], carry = bits.Add64(r[1], 6034159408538082302, carry) - r[2], carry = bits.Add64(r[2], 3691218898639771653, carry) - r[3], _ = bits.Add64(r[3], 8353516859464449352, carry) + // the high word of m * qElement[3] is at most 62 bits + // x[3] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[4] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + z[3], z[2] = madd2(m, qElementWord3, t[i+3], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 8353516859464449352 || (z[3] == 8353516859464449352 && (z[2] < 3691218898639771653 || (z[2] == 3691218898639771653 && (z[1] < 6034159408538082302 || (z[1] == 6034159408538082302 && (z[0] < 18446744069414584321))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 18446744069414584321, 0) + z[1], b = bits.Sub64(z[1], 6034159408538082302, b) + z[2], b = bits.Sub64(z[2], 3691218898639771653, b) + z[3], _ = bits.Sub64(z[3], 8353516859464449352, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], _ = bits.Add64(neg1, qElementWord3, b) } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} + +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) - if bigger { + return xHi +} - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - if borrow == 1 { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) - // s = s + q - s[0], carry = bits.Add64(s[0], 18446744069414584321, 0) - s[1], carry = bits.Add64(s[1], 6034159408538082302, carry) - s[2], carry = bits.Add64(s[2], 3691218898639771653, carry) - s[3], _ = bits.Add64(s[3], 8353516859464449352, carry) + if y < 0 { + c = z.neg(z, c) + } - } - } else { + return c +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - if borrow == 1 { + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - // r = r + q - r[0], carry = bits.Add64(r[0], 18446744069414584321, 0) - r[1], carry = bits.Add64(r[1], 6034159408538082302, carry) - r[2], carry = bits.Add64(r[2], 3691218898639771653, carry) - r[3], _ = bits.Add64(r[3], 8353516859464449352, carry) + h[0], s[0] = bits.Mul64(x[0], w) - } + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 3 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls12-381/fr/element_test.go b/ecc/bls12-381/fr/element_test.go index a5f26098d0..916508a19f 100644 --- a/ecc/bls12-381/fr/element_test.go +++ b/ecc/bls12-381/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -271,7 +272,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1962,3 +1963,500 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bls24-315/fp/element.go b/ecc/bls24-315/fp/element.go index 8599cb423a..7c38b06d06 100644 --- a/ecc/bls24-315/fp/element.go +++ b/ecc/bls24-315/fp/element.go @@ -63,14 +63,23 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 8063698428123676673 +const qElementWord1 uint64 = 4764498181658371330 +const qElementWord2 uint64 = 16051339359738796768 +const qElementWord3 uint64 = 15273757526516850351 +const qElementWord4 uint64 = 342900304943437392 + var qElement = Element{ - 8063698428123676673, - 4764498181658371330, - 16051339359738796768, - 15273757526516850351, - 342900304943437392, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 8083954730842193919 + // rSquare var rSquare = Element{ 7746605402484284438, @@ -192,7 +201,7 @@ func (z *Element) IsZero() bool { return (z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[4] | z[3] | z[2] | z[1]) == 0 } @@ -269,7 +278,7 @@ func (z *Element) SetRandom() (*Element, error) { z[4] = binary.BigEndian.Uint64(bytes[32:40]) z[4] %= 342900304943437392 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -443,7 +452,73 @@ func _mulGeneric(z, x, y *Element) { z[4], z[3] = madd3(m, 342900304943437392, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 8063698428123676673, 0) + z[1], b = bits.Sub64(z[1], 4764498181658371330, b) + z[2], b = bits.Sub64(z[2], 16051339359738796768, b) + z[3], b = bits.Sub64(z[3], 15273757526516850351, b) + z[4], _ = bits.Sub64(z[4], 342900304943437392, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [5]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 4764498181658371330, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 16051339359738796768, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 15273757526516850351, c2, c0) + c1, c0 = madd1(y, x[4], c1) + t[4], t[3] = madd3(m, 342900304943437392, c0, c2, c1) + } + { + // round 1 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 2 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 3 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 4 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, z[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, z[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, z[2] = madd2(m, 15273757526516850351, c2, t[3]) + z[4], z[3] = madd2(m, 342900304943437392, t[4], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -509,7 +584,7 @@ func _fromMontGeneric(z *Element) { z[4] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -530,7 +605,7 @@ func _addGeneric(z, x, y *Element) { z[3], carry = bits.Add64(x[3], y[3], carry) z[4], _ = bits.Add64(x[4], y[4], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -551,7 +626,7 @@ func _doubleGeneric(z, x *Element) { z[3], carry = bits.Add64(x[3], x[3], carry) z[4], _ = bits.Add64(x[4], x[4], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -595,7 +670,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -707,7 +782,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -839,7 +914,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1026,167 +1101,456 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 13359241550610159594 +const inversionCorrectionFactorWord1 = 7624632887220174691 +const inversionCorrectionFactorWord2 = 6412344873752403825 +const inversionCorrectionFactorWord3 = 11214014560053792263 +const inversionCorrectionFactorWord4 = 75428258669939399 + +const invIterationsN = 22 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 8063698428123676673, - 4764498181658371330, - 16051339359738796768, - 15273757526516850351, - 342900304943437392, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 7746605402484284438, - 6457291528853138485, - 14067144135019420374, - 14705958577488011058, - 150264569250089173, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 8063698428123676673, 0) - s[1], carry = bits.Add64(s[1], 4764498181658371330, carry) - s[2], carry = bits.Add64(s[2], 16051339359738796768, carry) - s[3], carry = bits.Add64(s[3], 15273757526516850351, carry) - s[4], _ = bits.Add64(s[4], 342900304943437392, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 8063698428123676673, 0) - r[1], carry = bits.Add64(r[1], 4764498181658371330, carry) - r[2], carry = bits.Add64(r[2], 16051339359738796768, carry) - r[3], carry = bits.Add64(r[3], 15273757526516850351, carry) - r[4], _ = bits.Add64(r[4], 342900304943437392, carry) + // the high word of m * qElement[4] is at most 62 bits + // x[4] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[5] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] >>= 1 + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + + t[i+Limbs] += C + } + { + const i = 4 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + z[4], z[3] = madd2(m, qElementWord4, t[i+4], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 8063698428123676673, 0) + z[1], b = bits.Sub64(z[1], 4764498181658371330, b) + z[2], b = bits.Sub64(z[2], 16051339359738796768, b) + z[3], b = bits.Sub64(z[3], 15273757526516850351, b) + z[4], _ = bits.Sub64(z[4], 342900304943437392, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[4] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], _ = bits.Add64(neg1, qElementWord4, b) } + } +} - // v >= u - bigger = !(v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))) +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - if bigger { +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], _ = bits.Sub64(v[4], u[4], borrow) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + xHi, _ = bits.Sub64(0, xHi, b) - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) + return xHi +} - if borrow == 1 { +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s + q - s[0], carry = bits.Add64(s[0], 8063698428123676673, 0) - s[1], carry = bits.Add64(s[1], 4764498181658371330, carry) - s[2], carry = bits.Add64(s[2], 16051339359738796768, carry) - s[3], carry = bits.Add64(s[3], 15273757526516850351, carry) - s[4], _ = bits.Add64(s[4], 342900304943437392, carry) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - } - } else { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], _ = bits.Sub64(u[4], v[4], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 8063698428123676673, 0) - r[1], carry = bits.Add64(r[1], 4764498181658371330, carry) - r[2], carry = bits.Add64(r[2], 16051339359738796768, carry) - r[3], carry = bits.Add64(r[3], 15273757526516850351, carry) - r[4], _ = bits.Add64(r[4], 342900304943437392, carry) + if y < 0 { + c = z.neg(z, c) + } - } + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w + + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 + + h[0], s[0] = bits.Mul64(x[0], w) + + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 4 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls24-315/fp/element_test.go b/ecc/bls24-315/fp/element_test.go index bd589a54e2..03a68cc848 100644 --- a/ecc/bls24-315/fp/element_test.go +++ b/ecc/bls24-315/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -273,7 +274,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1976,3 +1977,502 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bls24-315/fr/element.go b/ecc/bls24-315/fr/element.go index b91ead8114..0f5ff73699 100644 --- a/ecc/bls24-315/fr/element.go +++ b/ecc/bls24-315/fr/element.go @@ -63,13 +63,21 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 1860204336533995521 +const qElementWord1 uint64 = 14466829657984787300 +const qElementWord2 uint64 = 2737202078770428568 +const qElementWord3 uint64 = 1832378743606059307 + var qElement = Element{ - 1860204336533995521, - 14466829657984787300, - 2737202078770428568, - 1832378743606059307, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 2184305180030271487 + // rSquare var rSquare = Element{ 6242551132904523857, @@ -187,7 +195,7 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[3] | z[2] | z[1]) == 0 } @@ -257,7 +265,7 @@ func (z *Element) SetRandom() (*Element, error) { z[3] = binary.BigEndian.Uint64(bytes[24:32]) z[3] %= 1832378743606059307 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -405,7 +413,58 @@ func _mulGeneric(z, x, y *Element) { z[3], z[2] = madd3(m, 1832378743606059307, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 1860204336533995521, 0) + z[1], b = bits.Sub64(z[1], 14466829657984787300, b) + z[2], b = bits.Sub64(z[2], 2737202078770428568, b) + z[3], _ = bits.Sub64(z[3], 1832378743606059307, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [4]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 2184305180030271487 + c2 := madd0(m, 1860204336533995521, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 14466829657984787300, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 2737202078770428568, c2, c0) + c1, c0 = madd1(y, x[3], c1) + t[3], t[2] = madd3(m, 1832378743606059307, c0, c2, c1) + } + { + // round 1 + m := t[0] * 2184305180030271487 + c2 := madd0(m, 1860204336533995521, t[0]) + c2, t[0] = madd2(m, 14466829657984787300, c2, t[1]) + c2, t[1] = madd2(m, 2737202078770428568, c2, t[2]) + t[3], t[2] = madd2(m, 1832378743606059307, t[3], c2) + } + { + // round 2 + m := t[0] * 2184305180030271487 + c2 := madd0(m, 1860204336533995521, t[0]) + c2, t[0] = madd2(m, 14466829657984787300, c2, t[1]) + c2, t[1] = madd2(m, 2737202078770428568, c2, t[2]) + t[3], t[2] = madd2(m, 1832378743606059307, t[3], c2) + } + { + // round 3 + m := t[0] * 2184305180030271487 + c2 := madd0(m, 1860204336533995521, t[0]) + c2, z[0] = madd2(m, 14466829657984787300, c2, t[1]) + c2, z[1] = madd2(m, 2737202078770428568, c2, t[2]) + z[3], z[2] = madd2(m, 1832378743606059307, t[3], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -456,7 +515,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -475,7 +534,7 @@ func _addGeneric(z, x, y *Element) { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -494,7 +553,7 @@ func _doubleGeneric(z, x *Element) { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -534,7 +593,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { var b uint64 @@ -642,7 +701,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -772,7 +831,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -958,153 +1017,418 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 11227523567874677897 +const inversionCorrectionFactorWord1 = 17969463060078018966 +const inversionCorrectionFactorWord2 = 11691022387976210672 +const inversionCorrectionFactorWord3 = 904271400317574999 + +const invIterationsN = 18 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 1860204336533995521, - 14466829657984787300, - 2737202078770428568, - 1832378743606059307, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 6242551132904523857, - 16951295617263545407, - 10923821274252739203, - 584663452775307866, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 1860204336533995521, 0) - s[1], carry = bits.Add64(s[1], 14466829657984787300, carry) - s[2], carry = bits.Add64(s[2], 2737202078770428568, carry) - s[3], _ = bits.Add64(s[3], 1832378743606059307, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 1860204336533995521, 0) - r[1], carry = bits.Add64(r[1], 14466829657984787300, carry) - r[2], carry = bits.Add64(r[2], 2737202078770428568, carry) - r[3], _ = bits.Add64(r[3], 1832378743606059307, carry) + // the high word of m * qElement[3] is at most 62 bits + // x[3] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[4] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + z[3], z[2] = madd2(m, qElementWord3, t[i+3], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 1832378743606059307 || (z[3] == 1832378743606059307 && (z[2] < 2737202078770428568 || (z[2] == 2737202078770428568 && (z[1] < 14466829657984787300 || (z[1] == 14466829657984787300 && (z[0] < 1860204336533995521))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 1860204336533995521, 0) + z[1], b = bits.Sub64(z[1], 14466829657984787300, b) + z[2], b = bits.Sub64(z[2], 2737202078770428568, b) + z[3], _ = bits.Sub64(z[3], 1832378743606059307, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], _ = bits.Add64(neg1, qElementWord3, b) } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} + +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) - if bigger { + return xHi +} - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - if borrow == 1 { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) - // s = s + q - s[0], carry = bits.Add64(s[0], 1860204336533995521, 0) - s[1], carry = bits.Add64(s[1], 14466829657984787300, carry) - s[2], carry = bits.Add64(s[2], 2737202078770428568, carry) - s[3], _ = bits.Add64(s[3], 1832378743606059307, carry) + if y < 0 { + c = z.neg(z, c) + } - } - } else { + return c +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - if borrow == 1 { + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - // r = r + q - r[0], carry = bits.Add64(r[0], 1860204336533995521, 0) - r[1], carry = bits.Add64(r[1], 14466829657984787300, carry) - r[2], carry = bits.Add64(r[2], 2737202078770428568, carry) - r[3], _ = bits.Add64(r[3], 1832378743606059307, carry) + h[0], s[0] = bits.Mul64(x[0], w) - } + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 3 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bls24-315/fr/element_test.go b/ecc/bls24-315/fr/element_test.go index ff1cebf4d4..ba8f8c3790 100644 --- a/ecc/bls24-315/fr/element_test.go +++ b/ecc/bls24-315/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -271,7 +272,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1962,3 +1963,500 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bn254/fp/element.go b/ecc/bn254/fp/element.go index f1b78434ee..9d0de180c9 100644 --- a/ecc/bn254/fp/element.go +++ b/ecc/bn254/fp/element.go @@ -63,13 +63,21 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 4332616871279656263 +const qElementWord1 uint64 = 10917124144477883021 +const qElementWord2 uint64 = 13281191951274694749 +const qElementWord3 uint64 = 3486998266802970665 + var qElement = Element{ - 4332616871279656263, - 10917124144477883021, - 13281191951274694749, - 3486998266802970665, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 9786893198990664585 + // rSquare var rSquare = Element{ 17522657719365597833, @@ -187,7 +195,7 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[3] | z[2] | z[1]) == 0 } @@ -257,7 +265,7 @@ func (z *Element) SetRandom() (*Element, error) { z[3] = binary.BigEndian.Uint64(bytes[24:32]) z[3] %= 3486998266802970665 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -405,7 +413,58 @@ func _mulGeneric(z, x, y *Element) { z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4332616871279656263, 0) + z[1], b = bits.Sub64(z[1], 10917124144477883021, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [4]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 9786893198990664585 + c2 := madd0(m, 4332616871279656263, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 10917124144477883021, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 13281191951274694749, c2, c0) + c1, c0 = madd1(y, x[3], c1) + t[3], t[2] = madd3(m, 3486998266802970665, c0, c2, c1) + } + { + // round 1 + m := t[0] * 9786893198990664585 + c2 := madd0(m, 4332616871279656263, t[0]) + c2, t[0] = madd2(m, 10917124144477883021, c2, t[1]) + c2, t[1] = madd2(m, 13281191951274694749, c2, t[2]) + t[3], t[2] = madd2(m, 3486998266802970665, t[3], c2) + } + { + // round 2 + m := t[0] * 9786893198990664585 + c2 := madd0(m, 4332616871279656263, t[0]) + c2, t[0] = madd2(m, 10917124144477883021, c2, t[1]) + c2, t[1] = madd2(m, 13281191951274694749, c2, t[2]) + t[3], t[2] = madd2(m, 3486998266802970665, t[3], c2) + } + { + // round 3 + m := t[0] * 9786893198990664585 + c2 := madd0(m, 4332616871279656263, t[0]) + c2, z[0] = madd2(m, 10917124144477883021, c2, t[1]) + c2, z[1] = madd2(m, 13281191951274694749, c2, t[2]) + z[3], z[2] = madd2(m, 3486998266802970665, t[3], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -456,7 +515,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -475,7 +534,7 @@ func _addGeneric(z, x, y *Element) { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -494,7 +553,7 @@ func _doubleGeneric(z, x *Element) { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -534,7 +593,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { var b uint64 @@ -642,7 +701,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -772,7 +831,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -906,153 +965,418 @@ func (z *Element) Sqrt(x *Element) *Element { return nil } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 11111708840330028223 +const inversionCorrectionFactorWord1 = 3098618286181893933 +const inversionCorrectionFactorWord2 = 756602578711705709 +const inversionCorrectionFactorWord3 = 1041752015607019851 + +const invIterationsN = 18 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 4332616871279656263, - 10917124144477883021, - 13281191951274694749, - 3486998266802970665, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } + + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // initialize s = r^2 - var s = Element{ - 17522657719365597833, - 13107472804851548667, - 5164255478447964150, - 493319470278259999, + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + return z +} + +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { + + if nBits <= 64 { + return x[0] } - // r = 0 - r := Element{} + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - v := *x + hiWordIndex := (nBits - 1) / 64 - var carry, borrow uint64 - var bigger bool + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - for { - for v[0]&1 == 0 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // v = v >> 1 + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + return lo | mid | hi +} - if s[0]&1 == 1 { +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - // s = s + q - s[0], carry = bits.Add64(s[0], 4332616871279656263, 0) - s[1], carry = bits.Add64(s[1], 10917124144477883021, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // s = s >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 + m := x[0] * qInvNegLsw - } - for u[0]&1 == 0 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) - // u = u >> 1 + // the high word of m * qElement[3] is at most 62 bits + // x[3] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[4] = xHi + C + // xHi and C are 63 bits, therefore no overflow - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + { + const i = 1 + m = t[i] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 4332616871279656263, 0) - r[1], carry = bits.Add64(r[1], 10917124144477883021, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw - } + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - // r = r >> 1 + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNegLsw - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + z[3], z[2] = madd2(m, qElementWord3, t[i+3], C) + } + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 10917124144477883021 || (z[1] == 10917124144477883021 && (z[0] < 4332616871279656263))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4332616871279656263, 0) + z[1], b = bits.Sub64(z[1], 10917124144477883021, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], _ = bits.Add64(neg1, qElementWord3, b) } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} + +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) - if bigger { + return xHi +} - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - if borrow == 1 { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) - // s = s + q - s[0], carry = bits.Add64(s[0], 4332616871279656263, 0) - s[1], carry = bits.Add64(s[1], 10917124144477883021, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + if y < 0 { + c = z.neg(z, c) + } - } - } else { + return c +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - if borrow == 1 { + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - // r = r + q - r[0], carry = bits.Add64(r[0], 4332616871279656263, 0) - r[1], carry = bits.Add64(r[1], 10917124144477883021, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + h[0], s[0] = bits.Mul64(x[0], w) - } + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 3 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bn254/fp/element_test.go b/ecc/bn254/fp/element_test.go index 50cc3abf5f..6d35fa6d9c 100644 --- a/ecc/bn254/fp/element_test.go +++ b/ecc/bn254/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -271,7 +272,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1962,3 +1963,500 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bn254/fr/element.go b/ecc/bn254/fr/element.go index dea2364d39..6bca676657 100644 --- a/ecc/bn254/fr/element.go +++ b/ecc/bn254/fr/element.go @@ -63,13 +63,21 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 4891460686036598785 +const qElementWord1 uint64 = 2896914383306846353 +const qElementWord2 uint64 = 13281191951274694749 +const qElementWord3 uint64 = 3486998266802970665 + var qElement = Element{ - 4891460686036598785, - 2896914383306846353, - 13281191951274694749, - 3486998266802970665, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 14042775128853446655 + // rSquare var rSquare = Element{ 1997599621687373223, @@ -187,7 +195,7 @@ func (z *Element) IsZero() bool { return (z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[3] | z[2] | z[1]) == 0 } @@ -257,7 +265,7 @@ func (z *Element) SetRandom() (*Element, error) { z[3] = binary.BigEndian.Uint64(bytes[24:32]) z[3] %= 3486998266802970665 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -405,7 +413,58 @@ func _mulGeneric(z, x, y *Element) { z[3], z[2] = madd3(m, 3486998266802970665, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [4]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 14042775128853446655 + c2 := madd0(m, 4891460686036598785, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 2896914383306846353, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 13281191951274694749, c2, c0) + c1, c0 = madd1(y, x[3], c1) + t[3], t[2] = madd3(m, 3486998266802970665, c0, c2, c1) + } + { + // round 1 + m := t[0] * 14042775128853446655 + c2 := madd0(m, 4891460686036598785, t[0]) + c2, t[0] = madd2(m, 2896914383306846353, c2, t[1]) + c2, t[1] = madd2(m, 13281191951274694749, c2, t[2]) + t[3], t[2] = madd2(m, 3486998266802970665, t[3], c2) + } + { + // round 2 + m := t[0] * 14042775128853446655 + c2 := madd0(m, 4891460686036598785, t[0]) + c2, t[0] = madd2(m, 2896914383306846353, c2, t[1]) + c2, t[1] = madd2(m, 13281191951274694749, c2, t[2]) + t[3], t[2] = madd2(m, 3486998266802970665, t[3], c2) + } + { + // round 3 + m := t[0] * 14042775128853446655 + c2 := madd0(m, 4891460686036598785, t[0]) + c2, z[0] = madd2(m, 2896914383306846353, c2, t[1]) + c2, z[1] = madd2(m, 13281191951274694749, c2, t[2]) + z[3], z[2] = madd2(m, 3486998266802970665, t[3], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -456,7 +515,7 @@ func _fromMontGeneric(z *Element) { z[3] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -475,7 +534,7 @@ func _addGeneric(z, x, y *Element) { z[2], carry = bits.Add64(x[2], y[2], carry) z[3], _ = bits.Add64(x[3], y[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -494,7 +553,7 @@ func _doubleGeneric(z, x *Element) { z[2], carry = bits.Add64(x[2], x[2], carry) z[3], _ = bits.Add64(x[3], x[3], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -534,7 +593,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { var b uint64 @@ -642,7 +701,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -772,7 +831,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -958,153 +1017,418 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 13488105295233737379 +const inversionCorrectionFactorWord1 = 17373395488625725466 +const inversionCorrectionFactorWord2 = 6831692495576925776 +const inversionCorrectionFactorWord3 = 3282329835997625403 + +const invIterationsN = 18 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 4891460686036598785, - 2896914383306846353, - 13281191951274694749, - 3486998266802970665, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 1997599621687373223, - 6052339484930628067, - 10108755138030829701, - 150537098327114917, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + // the high word of m * qElement[3] is at most 62 bits + // x[3] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[4] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] >>= 1 + t[i+Limbs] += C + } + { + const i = 3 + m := t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + z[3], z[2] = madd2(m, qElementWord3, t[i+3], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[3] < 3486998266802970665 || (z[3] == 3486998266802970665 && (z[2] < 13281191951274694749 || (z[2] == 13281191951274694749 && (z[1] < 2896914383306846353 || (z[1] == 2896914383306846353 && (z[0] < 4891460686036598785))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 4891460686036598785, 0) + z[1], b = bits.Sub64(z[1], 2896914383306846353, b) + z[2], b = bits.Sub64(z[2], 13281191951274694749, b) + z[3], _ = bits.Sub64(z[3], 3486998266802970665, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[3] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], _ = bits.Add64(neg1, qElementWord3, b) } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} + +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v >= u - bigger = !(v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + xHi, _ = bits.Sub64(0, xHi, b) - if bigger { + return xHi +} - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], _ = bits.Sub64(v[3], u[3], borrow) +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - if borrow == 1 { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) - // s = s + q - s[0], carry = bits.Add64(s[0], 4891460686036598785, 0) - s[1], carry = bits.Add64(s[1], 2896914383306846353, carry) - s[2], carry = bits.Add64(s[2], 13281191951274694749, carry) - s[3], _ = bits.Add64(s[3], 3486998266802970665, carry) + if y < 0 { + c = z.neg(z, c) + } - } - } else { + return c +} - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], _ = bits.Sub64(u[3], v[3], borrow) +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - if borrow == 1 { + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - // r = r + q - r[0], carry = bits.Add64(r[0], 4891460686036598785, 0) - r[1], carry = bits.Add64(r[1], 2896914383306846353, carry) - r[2], carry = bits.Add64(r[2], 13281191951274694749, carry) - r[3], _ = bits.Add64(r[3], 3486998266802970665, carry) + h[0], s[0] = bits.Mul64(x[0], w) - } + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 3 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bn254/fr/element_test.go b/ecc/bn254/fr/element_test.go index 7eee01fd3d..cf35006002 100644 --- a/ecc/bn254/fr/element_test.go +++ b/ecc/bn254/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -271,7 +272,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1962,3 +1963,500 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bw6-633/fp/element.go b/ecc/bw6-633/fp/element.go index 065a47a485..bdcb4d1702 100644 --- a/ecc/bw6-633/fp/element.go +++ b/ecc/bw6-633/fp/element.go @@ -63,19 +63,33 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 15512955586897510413 +const qElementWord1 uint64 = 4410884215886313276 +const qElementWord2 uint64 = 15543556715411259941 +const qElementWord3 uint64 = 9083347379620258823 +const qElementWord4 uint64 = 13320134076191308873 +const qElementWord5 uint64 = 9318693926755804304 +const qElementWord6 uint64 = 5645674015335635503 +const qElementWord7 uint64 = 12176845843281334983 +const qElementWord8 uint64 = 18165857675053050549 +const qElementWord9 uint64 = 82862755739295587 + var qElement = Element{ - 15512955586897510413, - 4410884215886313276, - 15543556715411259941, - 9083347379620258823, - 13320134076191308873, - 9318693926755804304, - 5645674015335635503, - 12176845843281334983, - 18165857675053050549, - 82862755739295587, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + qElementWord6, + qElementWord7, + qElementWord8, + qElementWord9, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 13046692460116554043 + // rSquare var rSquare = Element{ 7358459907925294924, @@ -217,7 +231,7 @@ func (z *Element) IsZero() bool { return (z[9] | z[8] | z[7] | z[6] | z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[9] | z[8] | z[7] | z[6] | z[5] | z[4] | z[3] | z[2] | z[1]) == 0 } @@ -329,7 +343,7 @@ func (z *Element) SetRandom() (*Element, error) { z[9] = binary.BigEndian.Uint64(bytes[72:80]) z[9] %= 82862755739295587 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -693,7 +707,178 @@ func _mulGeneric(z, x, y *Element) { z[9], z[8] = madd3(m, 82862755739295587, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 15512955586897510413, 0) + z[1], b = bits.Sub64(z[1], 4410884215886313276, b) + z[2], b = bits.Sub64(z[2], 15543556715411259941, b) + z[3], b = bits.Sub64(z[3], 9083347379620258823, b) + z[4], b = bits.Sub64(z[4], 13320134076191308873, b) + z[5], b = bits.Sub64(z[5], 9318693926755804304, b) + z[6], b = bits.Sub64(z[6], 5645674015335635503, b) + z[7], b = bits.Sub64(z[7], 12176845843281334983, b) + z[8], b = bits.Sub64(z[8], 18165857675053050549, b) + z[9], _ = bits.Sub64(z[9], 82862755739295587, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [10]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 4410884215886313276, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 15543556715411259941, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 9083347379620258823, c2, c0) + c1, c0 = madd1(y, x[4], c1) + c2, t[3] = madd2(m, 13320134076191308873, c2, c0) + c1, c0 = madd1(y, x[5], c1) + c2, t[4] = madd2(m, 9318693926755804304, c2, c0) + c1, c0 = madd1(y, x[6], c1) + c2, t[5] = madd2(m, 5645674015335635503, c2, c0) + c1, c0 = madd1(y, x[7], c1) + c2, t[6] = madd2(m, 12176845843281334983, c2, c0) + c1, c0 = madd1(y, x[8], c1) + c2, t[7] = madd2(m, 18165857675053050549, c2, c0) + c1, c0 = madd1(y, x[9], c1) + t[9], t[8] = madd3(m, 82862755739295587, c0, c2, c1) + } + { + // round 1 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 2 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 3 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 4 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 5 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 6 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 7 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 8 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, t[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, t[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, t[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, t[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, t[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, t[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, t[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, t[7] = madd2(m, 18165857675053050549, c2, t[8]) + t[9], t[8] = madd2(m, 82862755739295587, t[9], c2) + } + { + // round 9 + m := t[0] * 13046692460116554043 + c2 := madd0(m, 15512955586897510413, t[0]) + c2, z[0] = madd2(m, 4410884215886313276, c2, t[1]) + c2, z[1] = madd2(m, 15543556715411259941, c2, t[2]) + c2, z[2] = madd2(m, 9083347379620258823, c2, t[3]) + c2, z[3] = madd2(m, 13320134076191308873, c2, t[4]) + c2, z[4] = madd2(m, 9318693926755804304, c2, t[5]) + c2, z[5] = madd2(m, 5645674015335635503, c2, t[6]) + c2, z[6] = madd2(m, 12176845843281334983, c2, t[7]) + c2, z[7] = madd2(m, 18165857675053050549, c2, t[8]) + z[9], z[8] = madd2(m, 82862755739295587, t[9], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -864,7 +1049,7 @@ func _fromMontGeneric(z *Element) { z[9] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -895,7 +1080,7 @@ func _addGeneric(z, x, y *Element) { z[8], carry = bits.Add64(x[8], y[8], carry) z[9], _ = bits.Add64(x[9], y[9], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -926,7 +1111,7 @@ func _doubleGeneric(z, x *Element) { z[8], carry = bits.Add64(x[8], x[8], carry) z[9], _ = bits.Add64(x[9], x[9], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -990,7 +1175,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { var b uint64 @@ -1122,7 +1307,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -1264,7 +1449,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1407,237 +1592,676 @@ func (z *Element) Sqrt(x *Element) *Element { return nil } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 17335095338408674528 +const inversionCorrectionFactorWord1 = 1935156146725576072 +const inversionCorrectionFactorWord2 = 12310223143035529855 +const inversionCorrectionFactorWord3 = 14776388015283991997 +const inversionCorrectionFactorWord4 = 13807356859349388480 +const inversionCorrectionFactorWord5 = 10412247811534140886 +const inversionCorrectionFactorWord6 = 1537112855455741892 +const inversionCorrectionFactorWord7 = 5281081904757642912 +const inversionCorrectionFactorWord8 = 14734303888675989218 +const inversionCorrectionFactorWord9 = 64202171737444348 + +const invIterationsN = 42 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 15512955586897510413, - 4410884215886313276, - 15543556715411259941, - 9083347379620258823, - 13320134076191308873, - 9318693926755804304, - 5645674015335635503, - 12176845843281334983, - 18165857675053050549, - 82862755739295587, - } - - // initialize s = r^2 - var s = Element{ - 7358459907925294924, - 14414180951914241931, - 16619482658146888203, - 760736596725344926, - 12753071240931896792, - 13425190760400245818, - 12591714441439252728, - 15325516497554583360, - 5301152003049442834, - 35368377961363834, - } - - // r = 0 - r := Element{} - - v := *x - - var carry, borrow uint64 - var bigger bool - - for { - for v[0]&1 == 0 { - - // v = v >> 1 - - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] = v[4]>>1 | v[5]<<63 - v[5] = v[5]>>1 | v[6]<<63 - v[6] = v[6]>>1 | v[7]<<63 - v[7] = v[7]>>1 | v[8]<<63 - v[8] = v[8]>>1 | v[9]<<63 - v[9] >>= 1 - - if s[0]&1 == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 15512955586897510413, 0) - s[1], carry = bits.Add64(s[1], 4410884215886313276, carry) - s[2], carry = bits.Add64(s[2], 15543556715411259941, carry) - s[3], carry = bits.Add64(s[3], 9083347379620258823, carry) - s[4], carry = bits.Add64(s[4], 13320134076191308873, carry) - s[5], carry = bits.Add64(s[5], 9318693926755804304, carry) - s[6], carry = bits.Add64(s[6], 5645674015335635503, carry) - s[7], carry = bits.Add64(s[7], 12176845843281334983, carry) - s[8], carry = bits.Add64(s[8], 18165857675053050549, carry) - s[9], _ = bits.Add64(s[9], 82862755739295587, carry) + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + qElementWord6, + qElementWord7, + qElementWord8, + qElementWord9, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still } - // s = s >> 1 + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] = s[4]>>1 | s[5]<<63 - s[5] = s[5]>>1 | s[6]<<63 - s[6] = s[6]>>1 | s[7]<<63 - s[7] = s[7]>>1 | s[8]<<63 - s[8] = s[8]>>1 | s[9]<<63 - s[9] >>= 1 + s = a + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | ((a[5]) << approxHighBitsN) + a[5] = (a[5] >> approxLowBitsN) | ((a[6]) << approxHighBitsN) + a[6] = (a[6] >> approxLowBitsN) | ((a[7]) << approxHighBitsN) + a[7] = (a[7] >> approxLowBitsN) | ((a[8]) << approxHighBitsN) + a[8] = (a[8] >> approxLowBitsN) | ((a[9]) << approxHighBitsN) + a[9] = (a[9] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) } - for u[0]&1 == 0 { - - // u = u >> 1 - - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] = u[4]>>1 | u[5]<<63 - u[5] = u[5]>>1 | u[6]<<63 - u[6] = u[6]>>1 | u[7]<<63 - u[7] = u[7]>>1 | u[8]<<63 - u[8] = u[8]>>1 | u[9]<<63 - u[9] >>= 1 - - if r[0]&1 == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 15512955586897510413, 0) - r[1], carry = bits.Add64(r[1], 4410884215886313276, carry) - r[2], carry = bits.Add64(r[2], 15543556715411259941, carry) - r[3], carry = bits.Add64(r[3], 9083347379620258823, carry) - r[4], carry = bits.Add64(r[4], 13320134076191308873, carry) - r[5], carry = bits.Add64(r[5], 9318693926755804304, carry) - r[6], carry = bits.Add64(r[6], 5645674015335635503, carry) - r[7], carry = bits.Add64(r[7], 12176845843281334983, carry) - r[8], carry = bits.Add64(r[8], 18165857675053050549, carry) - r[9], _ = bits.Add64(r[9], 82862755739295587, carry) + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | ((b[5]) << approxHighBitsN) + b[5] = (b[5] >> approxLowBitsN) | ((b[6]) << approxHighBitsN) + b[6] = (b[6] >> approxLowBitsN) | ((b[7]) << approxHighBitsN) + b[7] = (b[7] >> approxLowBitsN) | ((b[8]) << approxHighBitsN) + b[8] = (b[8] >> approxLowBitsN) | ((b[9]) << approxHighBitsN) + b[9] = (b[9] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) - } + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } + + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) + } + + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + }) + return z +} + +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { + + if nBits <= 64 { + return x[0] + } + + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] + + hiWordIndex := (nBits - 1) / 64 + + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) + + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} - // r = r >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] = r[4]>>1 | r[5]<<63 - r[5] = r[5]>>1 | r[6]<<63 - r[6] = r[6]>>1 | r[7]<<63 - r[7] = r[7]>>1 | r[8]<<63 - r[8] = r[8]>>1 | r[9]<<63 - r[9] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) + C, t[5] = madd2(m, qElementWord5, x[5], C) + C, t[6] = madd2(m, qElementWord6, x[6], C) + C, t[7] = madd2(m, qElementWord7, x[7], C) + C, t[8] = madd2(m, qElementWord8, x[8], C) + C, t[9] = madd2(m, qElementWord9, x[9], C) + + // the high word of m * qElement[9] is at most 62 bits + // x[9] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[10] = xHi + C + // xHi and C are 63 bits, therefore no overflow + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 4 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 5 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 6 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 7 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 8 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + + t[i+Limbs] += C + } + { + const i = 9 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + C, z[3] = madd2(m, qElementWord4, t[i+4], C) + C, z[4] = madd2(m, qElementWord5, t[i+5], C) + C, z[5] = madd2(m, qElementWord6, t[i+6], C) + C, z[6] = madd2(m, qElementWord7, t[i+7], C) + C, z[7] = madd2(m, qElementWord8, t[i+8], C) + z[9], z[8] = madd2(m, qElementWord9, t[i+9], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[9] < 82862755739295587 || (z[9] == 82862755739295587 && (z[8] < 18165857675053050549 || (z[8] == 18165857675053050549 && (z[7] < 12176845843281334983 || (z[7] == 12176845843281334983 && (z[6] < 5645674015335635503 || (z[6] == 5645674015335635503 && (z[5] < 9318693926755804304 || (z[5] == 9318693926755804304 && (z[4] < 13320134076191308873 || (z[4] == 13320134076191308873 && (z[3] < 9083347379620258823 || (z[3] == 9083347379620258823 && (z[2] < 15543556715411259941 || (z[2] == 15543556715411259941 && (z[1] < 4410884215886313276 || (z[1] == 4410884215886313276 && (z[0] < 15512955586897510413))))))))))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 15512955586897510413, 0) + z[1], b = bits.Sub64(z[1], 4410884215886313276, b) + z[2], b = bits.Sub64(z[2], 15543556715411259941, b) + z[3], b = bits.Sub64(z[3], 9083347379620258823, b) + z[4], b = bits.Sub64(z[4], 13320134076191308873, b) + z[5], b = bits.Sub64(z[5], 9318693926755804304, b) + z[6], b = bits.Sub64(z[6], 5645674015335635503, b) + z[7], b = bits.Sub64(z[7], 12176845843281334983, b) + z[8], b = bits.Sub64(z[8], 18165857675053050549, b) + z[9], _ = bits.Sub64(z[9], 82862755739295587, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + z[5], b = bits.Sub64(z[5], 0, b) + z[6], b = bits.Sub64(z[6], 0, b) + z[7], b = bits.Sub64(z[7], 0, b) + z[8], b = bits.Sub64(z[8], 0, b) + z[9], b = bits.Sub64(z[9], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[9] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], b = bits.Add64(z[4], qElementWord4, b) + z[5], b = bits.Add64(z[5], qElementWord5, b) + z[6], b = bits.Add64(z[6], qElementWord6, b) + z[7], b = bits.Add64(z[7], qElementWord7, b) + z[8], b = bits.Add64(z[8], qElementWord8, b) + z[9], _ = bits.Add64(neg1, qElementWord9, b) } + } +} - // v >= u - bigger = !(v[9] < u[9] || (v[9] == u[9] && (v[8] < u[8] || (v[8] == u[8] && (v[7] < u[7] || (v[7] == u[7] && (v[6] < u[6] || (v[6] == u[6] && (v[5] < u[5] || (v[5] == u[5] && (v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))))))))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], borrow = bits.Sub64(v[4], u[4], borrow) - v[5], borrow = bits.Sub64(v[5], u[5], borrow) - v[6], borrow = bits.Sub64(v[6], u[6], borrow) - v[7], borrow = bits.Sub64(v[7], u[7], borrow) - v[8], borrow = bits.Sub64(v[8], u[8], borrow) - v[9], _ = bits.Sub64(v[9], u[9], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) - s[5], borrow = bits.Sub64(s[5], r[5], borrow) - s[6], borrow = bits.Sub64(s[6], r[6], borrow) - s[7], borrow = bits.Sub64(s[7], r[7], borrow) - s[8], borrow = bits.Sub64(s[8], r[8], borrow) - s[9], borrow = bits.Sub64(s[9], r[9], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 15512955586897510413, 0) - s[1], carry = bits.Add64(s[1], 4410884215886313276, carry) - s[2], carry = bits.Add64(s[2], 15543556715411259941, carry) - s[3], carry = bits.Add64(s[3], 9083347379620258823, carry) - s[4], carry = bits.Add64(s[4], 13320134076191308873, carry) - s[5], carry = bits.Add64(s[5], 9318693926755804304, carry) - s[6], carry = bits.Add64(s[6], 5645674015335635503, carry) - s[7], carry = bits.Add64(s[7], 12176845843281334983, carry) - s[8], carry = bits.Add64(s[8], 18165857675053050549, carry) - s[9], _ = bits.Add64(s[9], 82862755739295587, carry) +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - } - } else { +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], borrow = bits.Sub64(u[4], v[4], borrow) - u[5], borrow = bits.Sub64(u[5], v[5], borrow) - u[6], borrow = bits.Sub64(u[6], v[6], borrow) - u[7], borrow = bits.Sub64(u[7], v[7], borrow) - u[8], borrow = bits.Sub64(u[8], v[8], borrow) - u[9], _ = bits.Sub64(u[9], v[9], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - r[5], borrow = bits.Sub64(r[5], s[5], borrow) - r[6], borrow = bits.Sub64(r[6], s[6], borrow) - r[7], borrow = bits.Sub64(r[7], s[7], borrow) - r[8], borrow = bits.Sub64(r[8], s[8], borrow) - r[9], borrow = bits.Sub64(r[9], s[9], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 15512955586897510413, 0) - r[1], carry = bits.Add64(r[1], 4410884215886313276, carry) - r[2], carry = bits.Add64(r[2], 15543556715411259941, carry) - r[3], carry = bits.Add64(r[3], 9083347379620258823, carry) - r[4], carry = bits.Add64(r[4], 13320134076191308873, carry) - r[5], carry = bits.Add64(r[5], 9318693926755804304, carry) - r[6], carry = bits.Add64(r[6], 5645674015335635503, carry) - r[7], carry = bits.Add64(r[7], 12176845843281334983, carry) - r[8], carry = bits.Add64(r[8], 18165857675053050549, carry) - r[9], _ = bits.Add64(r[9], 82862755739295587, carry) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + z[5], b = bits.Sub64(0, x[5], b) + z[6], b = bits.Sub64(0, x[6], b) + z[7], b = bits.Sub64(0, x[7], b) + z[8], b = bits.Sub64(0, x[8], b) + z[9], b = bits.Sub64(0, x[9], b) + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} - } +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) + c, z[5] = madd1(x[5], w, c) + c, z[6] = madd1(x[6], w, c) + c, z[7] = madd1(x[7], w, c) + c, z[8] = madd1(x[8], w, c) + c, z[9] = madd1(x[9], w, c) + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w + + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 + + h[0], s[0] = bits.Mul64(x[0], w) + + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 5 - 1 + + h[curI], s[curI] = bits.Mul64(x[5], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[9]|u[8]|u[7]|u[6]|u[5]|u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 6 % 2 + const prevI = 1 - curI + const iMinusOne = 6 - 1 + + h[curI], s[curI] = bits.Mul64(x[6], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[9]|v[8]|v[7]|v[6]|v[5]|v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 7 % 2 + const prevI = 1 - curI + const iMinusOne = 7 - 1 + + h[curI], s[curI] = bits.Mul64(x[7], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 8 % 2 + const prevI = 1 - curI + const iMinusOne = 8 - 1 + + h[curI], s[curI] = bits.Mul64(x[8], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + + { + const curI = 9 % 2 + const prevI = 1 - curI + const iMinusOne = 9 - 1 + + h[curI], s[curI] = bits.Mul64(x[9], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + { + const curI = 10 % 2 + const prevI = 1 - curI + const iMinusOne = 9 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + z[5], carry = bits.Add64(z[5], yTimes[5], carry) + z[6], carry = bits.Add64(z[6], yTimes[6], carry) + z[7], carry = bits.Add64(z[7], yTimes[7], carry) + z[8], carry = bits.Add64(z[8], yTimes[8], carry) + z[9], carry = bits.Add64(z[9], yTimes[9], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bw6-633/fp/element_test.go b/ecc/bw6-633/fp/element_test.go index b71c5d76e0..0f314bdaa3 100644 --- a/ecc/bw6-633/fp/element_test.go +++ b/ecc/bw6-633/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -283,7 +284,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -2046,3 +2047,512 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bw6-633/fr/element.go b/ecc/bw6-633/fr/element.go index 17df8c82fd..65def8d48c 100644 --- a/ecc/bw6-633/fr/element.go +++ b/ecc/bw6-633/fr/element.go @@ -63,14 +63,23 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 8063698428123676673 +const qElementWord1 uint64 = 4764498181658371330 +const qElementWord2 uint64 = 16051339359738796768 +const qElementWord3 uint64 = 15273757526516850351 +const qElementWord4 uint64 = 342900304943437392 + var qElement = Element{ - 8063698428123676673, - 4764498181658371330, - 16051339359738796768, - 15273757526516850351, - 342900304943437392, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 8083954730842193919 + // rSquare var rSquare = Element{ 7746605402484284438, @@ -192,7 +201,7 @@ func (z *Element) IsZero() bool { return (z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[4] | z[3] | z[2] | z[1]) == 0 } @@ -269,7 +278,7 @@ func (z *Element) SetRandom() (*Element, error) { z[4] = binary.BigEndian.Uint64(bytes[32:40]) z[4] %= 342900304943437392 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -443,7 +452,73 @@ func _mulGeneric(z, x, y *Element) { z[4], z[3] = madd3(m, 342900304943437392, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 8063698428123676673, 0) + z[1], b = bits.Sub64(z[1], 4764498181658371330, b) + z[2], b = bits.Sub64(z[2], 16051339359738796768, b) + z[3], b = bits.Sub64(z[3], 15273757526516850351, b) + z[4], _ = bits.Sub64(z[4], 342900304943437392, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [5]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 4764498181658371330, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 16051339359738796768, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 15273757526516850351, c2, c0) + c1, c0 = madd1(y, x[4], c1) + t[4], t[3] = madd3(m, 342900304943437392, c0, c2, c1) + } + { + // round 1 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 2 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 3 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, t[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, t[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, t[2] = madd2(m, 15273757526516850351, c2, t[3]) + t[4], t[3] = madd2(m, 342900304943437392, t[4], c2) + } + { + // round 4 + m := t[0] * 8083954730842193919 + c2 := madd0(m, 8063698428123676673, t[0]) + c2, z[0] = madd2(m, 4764498181658371330, c2, t[1]) + c2, z[1] = madd2(m, 16051339359738796768, c2, t[2]) + c2, z[2] = madd2(m, 15273757526516850351, c2, t[3]) + z[4], z[3] = madd2(m, 342900304943437392, t[4], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -509,7 +584,7 @@ func _fromMontGeneric(z *Element) { z[4] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -530,7 +605,7 @@ func _addGeneric(z, x, y *Element) { z[3], carry = bits.Add64(x[3], y[3], carry) z[4], _ = bits.Add64(x[4], y[4], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -551,7 +626,7 @@ func _doubleGeneric(z, x *Element) { z[3], carry = bits.Add64(x[3], x[3], carry) z[4], _ = bits.Add64(x[4], x[4], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -595,7 +670,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { var b uint64 @@ -707,7 +782,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -839,7 +914,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1026,167 +1101,456 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 13359241550610159594 +const inversionCorrectionFactorWord1 = 7624632887220174691 +const inversionCorrectionFactorWord2 = 6412344873752403825 +const inversionCorrectionFactorWord3 = 11214014560053792263 +const inversionCorrectionFactorWord4 = 75428258669939399 + +const invIterationsN = 22 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 8063698428123676673, - 4764498181658371330, - 16051339359738796768, - 15273757526516850351, - 342900304943437392, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 7746605402484284438, - 6457291528853138485, - 14067144135019420374, - 14705958577488011058, - 150264569250089173, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 8063698428123676673, 0) - s[1], carry = bits.Add64(s[1], 4764498181658371330, carry) - s[2], carry = bits.Add64(s[2], 16051339359738796768, carry) - s[3], carry = bits.Add64(s[3], 15273757526516850351, carry) - s[4], _ = bits.Add64(s[4], 342900304943437392, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] >>= 1 +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { - } - for u[0]&1 == 0 { + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X - // u = u >> 1 + var t [2*Limbs - 1]uint64 + var C uint64 - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] >>= 1 + m := x[0] * qInvNegLsw - if r[0]&1 == 1 { + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) - // r = r + q - r[0], carry = bits.Add64(r[0], 8063698428123676673, 0) - r[1], carry = bits.Add64(r[1], 4764498181658371330, carry) - r[2], carry = bits.Add64(r[2], 16051339359738796768, carry) - r[3], carry = bits.Add64(r[3], 15273757526516850351, carry) - r[4], _ = bits.Add64(r[4], 342900304943437392, carry) + // the high word of m * qElement[4] is at most 62 bits + // x[4] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[5] = xHi + C + // xHi and C are 63 bits, therefore no overflow - } + { + const i = 1 + m = t[i] * qInvNegLsw - // r = r >> 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] >>= 1 + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + + t[i+Limbs] += C + } + { + const i = 4 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + z[4], z[3] = madd2(m, qElementWord4, t[i+4], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[4] < 342900304943437392 || (z[4] == 342900304943437392 && (z[3] < 15273757526516850351 || (z[3] == 15273757526516850351 && (z[2] < 16051339359738796768 || (z[2] == 16051339359738796768 && (z[1] < 4764498181658371330 || (z[1] == 4764498181658371330 && (z[0] < 8063698428123676673))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 8063698428123676673, 0) + z[1], b = bits.Sub64(z[1], 4764498181658371330, b) + z[2], b = bits.Sub64(z[2], 16051339359738796768, b) + z[3], b = bits.Sub64(z[3], 15273757526516850351, b) + z[4], _ = bits.Sub64(z[4], 342900304943437392, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[4] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], _ = bits.Add64(neg1, qElementWord4, b) } + } +} - // v >= u - bigger = !(v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))) +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - if bigger { +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], _ = bits.Sub64(v[4], u[4], borrow) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + xHi, _ = bits.Sub64(0, xHi, b) - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) + return xHi +} - if borrow == 1 { +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { - // s = s + q - s[0], carry = bits.Add64(s[0], 8063698428123676673, 0) - s[1], carry = bits.Add64(s[1], 4764498181658371330, carry) - s[2], carry = bits.Add64(s[2], 16051339359738796768, carry) - s[3], carry = bits.Add64(s[3], 15273757526516850351, carry) - s[4], _ = bits.Add64(s[4], 342900304943437392, carry) + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) - } - } else { + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], _ = bits.Sub64(u[4], v[4], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 8063698428123676673, 0) - r[1], carry = bits.Add64(r[1], 4764498181658371330, carry) - r[2], carry = bits.Add64(r[2], 16051339359738796768, carry) - r[3], carry = bits.Add64(r[3], 15273757526516850351, carry) - r[4], _ = bits.Add64(r[4], 342900304943437392, carry) + if y < 0 { + c = z.neg(z, c) + } - } + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w + + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 + + h[0], s[0] = bits.Mul64(x[0], w) + + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 4 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bw6-633/fr/element_test.go b/ecc/bw6-633/fr/element_test.go index ab427a1cbc..8df0fe87db 100644 --- a/ecc/bw6-633/fr/element_test.go +++ b/ecc/bw6-633/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -273,7 +274,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1976,3 +1977,502 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bw6-761/fp/element.go b/ecc/bw6-761/fp/element.go index 953507c023..85b67fe769 100644 --- a/ecc/bw6-761/fp/element.go +++ b/ecc/bw6-761/fp/element.go @@ -63,21 +63,37 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 17626244516597989515 +const qElementWord1 uint64 = 16614129118623039618 +const qElementWord2 uint64 = 1588918198704579639 +const qElementWord3 uint64 = 10998096788944562424 +const qElementWord4 uint64 = 8204665564953313070 +const qElementWord5 uint64 = 9694500593442880912 +const qElementWord6 uint64 = 274362232328168196 +const qElementWord7 uint64 = 8105254717682411801 +const qElementWord8 uint64 = 5945444129596489281 +const qElementWord9 uint64 = 13341377791855249032 +const qElementWord10 uint64 = 15098257552581525310 +const qElementWord11 uint64 = 81882988782276106 + var qElement = Element{ - 17626244516597989515, - 16614129118623039618, - 1588918198704579639, - 10998096788944562424, - 8204665564953313070, - 9694500593442880912, - 274362232328168196, - 8105254717682411801, - 5945444129596489281, - 13341377791855249032, - 15098257552581525310, - 81882988782276106, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + qElementWord6, + qElementWord7, + qElementWord8, + qElementWord9, + qElementWord10, + qElementWord11, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 744663313386281181 + // rSquare var rSquare = Element{ 14305184132582319705, @@ -227,7 +243,7 @@ func (z *Element) IsZero() bool { return (z[11] | z[10] | z[9] | z[8] | z[7] | z[6] | z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[11] | z[10] | z[9] | z[8] | z[7] | z[6] | z[5] | z[4] | z[3] | z[2] | z[1]) == 0 } @@ -353,7 +369,7 @@ func (z *Element) SetRandom() (*Element, error) { z[11] = binary.BigEndian.Uint64(bytes[88:96]) z[11] %= 81882988782276106 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -821,7 +837,234 @@ func _mulGeneric(z, x, y *Element) { z[11], z[10] = madd3(m, 81882988782276106, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 17626244516597989515, 0) + z[1], b = bits.Sub64(z[1], 16614129118623039618, b) + z[2], b = bits.Sub64(z[2], 1588918198704579639, b) + z[3], b = bits.Sub64(z[3], 10998096788944562424, b) + z[4], b = bits.Sub64(z[4], 8204665564953313070, b) + z[5], b = bits.Sub64(z[5], 9694500593442880912, b) + z[6], b = bits.Sub64(z[6], 274362232328168196, b) + z[7], b = bits.Sub64(z[7], 8105254717682411801, b) + z[8], b = bits.Sub64(z[8], 5945444129596489281, b) + z[9], b = bits.Sub64(z[9], 13341377791855249032, b) + z[10], b = bits.Sub64(z[10], 15098257552581525310, b) + z[11], _ = bits.Sub64(z[11], 81882988782276106, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [12]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 744663313386281181 + c2 := madd0(m, 17626244516597989515, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 16614129118623039618, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 1588918198704579639, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 10998096788944562424, c2, c0) + c1, c0 = madd1(y, x[4], c1) + c2, t[3] = madd2(m, 8204665564953313070, c2, c0) + c1, c0 = madd1(y, x[5], c1) + c2, t[4] = madd2(m, 9694500593442880912, c2, c0) + c1, c0 = madd1(y, x[6], c1) + c2, t[5] = madd2(m, 274362232328168196, c2, c0) + c1, c0 = madd1(y, x[7], c1) + c2, t[6] = madd2(m, 8105254717682411801, c2, c0) + c1, c0 = madd1(y, x[8], c1) + c2, t[7] = madd2(m, 5945444129596489281, c2, c0) + c1, c0 = madd1(y, x[9], c1) + c2, t[8] = madd2(m, 13341377791855249032, c2, c0) + c1, c0 = madd1(y, x[10], c1) + c2, t[9] = madd2(m, 15098257552581525310, c2, c0) + c1, c0 = madd1(y, x[11], c1) + t[11], t[10] = madd3(m, 81882988782276106, c0, c2, c1) + } + { + // round 1 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 2 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 3 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 4 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 5 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 6 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 7 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 8 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 9 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 10 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, t[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, t[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, t[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, t[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, t[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, t[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, t[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, t[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, t[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, t[9] = madd2(m, 15098257552581525310, c2, t[10]) + t[11], t[10] = madd2(m, 81882988782276106, t[11], c2) + } + { + // round 11 + m := t[0] * 744663313386281181 + c2 := madd0(m, 17626244516597989515, t[0]) + c2, z[0] = madd2(m, 16614129118623039618, c2, t[1]) + c2, z[1] = madd2(m, 1588918198704579639, c2, t[2]) + c2, z[2] = madd2(m, 10998096788944562424, c2, t[3]) + c2, z[3] = madd2(m, 8204665564953313070, c2, t[4]) + c2, z[4] = madd2(m, 9694500593442880912, c2, t[5]) + c2, z[5] = madd2(m, 274362232328168196, c2, t[6]) + c2, z[6] = madd2(m, 8105254717682411801, c2, t[7]) + c2, z[7] = madd2(m, 5945444129596489281, c2, t[8]) + c2, z[8] = madd2(m, 13341377791855249032, c2, t[9]) + c2, z[9] = madd2(m, 15098257552581525310, c2, t[10]) + z[11], z[10] = madd2(m, 81882988782276106, t[11], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -1048,7 +1291,7 @@ func _fromMontGeneric(z *Element) { z[11] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -1083,7 +1326,7 @@ func _addGeneric(z, x, y *Element) { z[10], carry = bits.Add64(x[10], y[10], carry) z[11], _ = bits.Add64(x[11], y[11], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -1118,7 +1361,7 @@ func _doubleGeneric(z, x *Element) { z[10], carry = bits.Add64(x[10], x[10], carry) z[11], _ = bits.Add64(x[11], x[11], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -1190,7 +1433,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { var b uint64 @@ -1330,7 +1573,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -1476,7 +1719,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1610,265 +1853,778 @@ func (z *Element) Sqrt(x *Element) *Element { return nil } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 10031381836020524396 +const inversionCorrectionFactorWord1 = 13512348327667036555 +const inversionCorrectionFactorWord2 = 4712458382768910368 +const inversionCorrectionFactorWord3 = 10608536430169864474 +const inversionCorrectionFactorWord4 = 6011516409926961524 +const inversionCorrectionFactorWord5 = 453532925360796333 +const inversionCorrectionFactorWord6 = 12814326068023107562 +const inversionCorrectionFactorWord7 = 4764844688547002673 +const inversionCorrectionFactorWord8 = 297975318568649638 +const inversionCorrectionFactorWord9 = 3076225984588847531 +const inversionCorrectionFactorWord10 = 327844855039329024 +const inversionCorrectionFactorWord11 = 65221078716978344 + +const invIterationsN = 50 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 17626244516597989515, - 16614129118623039618, - 1588918198704579639, - 10998096788944562424, - 8204665564953313070, - 9694500593442880912, - 274362232328168196, - 8105254717682411801, - 5945444129596489281, - 13341377791855249032, - 15098257552581525310, - 81882988782276106, - } - - // initialize s = r^2 - var s = Element{ - 14305184132582319705, - 8868935336694416555, - 9196887162930508889, - 15486798265448570248, - 5402985275949444416, - 10893197322525159598, - 3204916688966998390, - 12417238192559061753, - 12426306557607898622, - 1305582522441154384, - 10311846026977660324, - 48736111365249031, - } - - // r = 0 - r := Element{} - - v := *x - - var carry, borrow uint64 - var bigger bool - - for { - for v[0]&1 == 0 { - - // v = v >> 1 - - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] = v[4]>>1 | v[5]<<63 - v[5] = v[5]>>1 | v[6]<<63 - v[6] = v[6]>>1 | v[7]<<63 - v[7] = v[7]>>1 | v[8]<<63 - v[8] = v[8]>>1 | v[9]<<63 - v[9] = v[9]>>1 | v[10]<<63 - v[10] = v[10]>>1 | v[11]<<63 - v[11] >>= 1 - - if s[0]&1 == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 17626244516597989515, 0) - s[1], carry = bits.Add64(s[1], 16614129118623039618, carry) - s[2], carry = bits.Add64(s[2], 1588918198704579639, carry) - s[3], carry = bits.Add64(s[3], 10998096788944562424, carry) - s[4], carry = bits.Add64(s[4], 8204665564953313070, carry) - s[5], carry = bits.Add64(s[5], 9694500593442880912, carry) - s[6], carry = bits.Add64(s[6], 274362232328168196, carry) - s[7], carry = bits.Add64(s[7], 8105254717682411801, carry) - s[8], carry = bits.Add64(s[8], 5945444129596489281, carry) - s[9], carry = bits.Add64(s[9], 13341377791855249032, carry) - s[10], carry = bits.Add64(s[10], 15098257552581525310, carry) - s[11], _ = bits.Add64(s[11], 81882988782276106, carry) + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + qElementWord6, + qElementWord7, + qElementWord8, + qElementWord9, + qElementWord10, + qElementWord11, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still } - // s = s >> 1 - - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] = s[4]>>1 | s[5]<<63 - s[5] = s[5]>>1 | s[6]<<63 - s[6] = s[6]>>1 | s[7]<<63 - s[7] = s[7]>>1 | s[8]<<63 - s[8] = s[8]>>1 | s[9]<<63 - s[9] = s[9]>>1 | s[10]<<63 - s[10] = s[10]>>1 | s[11]<<63 - s[11] >>= 1 + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) } - for u[0]&1 == 0 { - - // u = u >> 1 - - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] = u[4]>>1 | u[5]<<63 - u[5] = u[5]>>1 | u[6]<<63 - u[6] = u[6]>>1 | u[7]<<63 - u[7] = u[7]>>1 | u[8]<<63 - u[8] = u[8]>>1 | u[9]<<63 - u[9] = u[9]>>1 | u[10]<<63 - u[10] = u[10]>>1 | u[11]<<63 - u[11] >>= 1 - - if r[0]&1 == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 17626244516597989515, 0) - r[1], carry = bits.Add64(r[1], 16614129118623039618, carry) - r[2], carry = bits.Add64(r[2], 1588918198704579639, carry) - r[3], carry = bits.Add64(r[3], 10998096788944562424, carry) - r[4], carry = bits.Add64(r[4], 8204665564953313070, carry) - r[5], carry = bits.Add64(r[5], 9694500593442880912, carry) - r[6], carry = bits.Add64(r[6], 274362232328168196, carry) - r[7], carry = bits.Add64(r[7], 8105254717682411801, carry) - r[8], carry = bits.Add64(r[8], 5945444129596489281, carry) - r[9], carry = bits.Add64(r[9], 13341377791855249032, carry) - r[10], carry = bits.Add64(r[10], 15098257552581525310, carry) - r[11], _ = bits.Add64(r[11], 81882988782276106, carry) + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | ((a[5]) << approxHighBitsN) + a[5] = (a[5] >> approxLowBitsN) | ((a[6]) << approxHighBitsN) + a[6] = (a[6] >> approxLowBitsN) | ((a[7]) << approxHighBitsN) + a[7] = (a[7] >> approxLowBitsN) | ((a[8]) << approxHighBitsN) + a[8] = (a[8] >> approxLowBitsN) | ((a[9]) << approxHighBitsN) + a[9] = (a[9] >> approxLowBitsN) | ((a[10]) << approxHighBitsN) + a[10] = (a[10] >> approxLowBitsN) | ((a[11]) << approxHighBitsN) + a[11] = (a[11] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | ((b[5]) << approxHighBitsN) + b[5] = (b[5] >> approxLowBitsN) | ((b[6]) << approxHighBitsN) + b[6] = (b[6] >> approxLowBitsN) | ((b[7]) << approxHighBitsN) + b[7] = (b[7] >> approxLowBitsN) | ((b[8]) << approxHighBitsN) + b[8] = (b[8] >> approxLowBitsN) | ((b[9]) << approxHighBitsN) + b[9] = (b[9] >> approxLowBitsN) | ((b[10]) << approxHighBitsN) + b[10] = (b[10] >> approxLowBitsN) | ((b[11]) << approxHighBitsN) + b[11] = (b[11] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) - } + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } - // r = r >> 1 - - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] = r[4]>>1 | r[5]<<63 - r[5] = r[5]>>1 | r[6]<<63 - r[6] = r[6]>>1 | r[7]<<63 - r[7] = r[7]>>1 | r[8]<<63 - r[8] = r[8]>>1 | r[9]<<63 - r[9] = r[9]>>1 | r[10]<<63 - r[10] = r[10]>>1 | r[11]<<63 - r[11] >>= 1 + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) + } + + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + inversionCorrectionFactorWord10, + inversionCorrectionFactorWord11, + }) + return z +} + +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { + + if nBits <= 64 { + return x[0] + } + + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] + + hiWordIndex := (nBits - 1) / 64 + + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) + + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) + C, t[5] = madd2(m, qElementWord5, x[5], C) + C, t[6] = madd2(m, qElementWord6, x[6], C) + C, t[7] = madd2(m, qElementWord7, x[7], C) + C, t[8] = madd2(m, qElementWord8, x[8], C) + C, t[9] = madd2(m, qElementWord9, x[9], C) + C, t[10] = madd2(m, qElementWord10, x[10], C) + C, t[11] = madd2(m, qElementWord11, x[11], C) + + // the high word of m * qElement[11] is at most 62 bits + // x[11] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[12] = xHi + C + // xHi and C are 63 bits, therefore no overflow + + { + const i = 1 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 4 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 5 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 6 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 7 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 8 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 9 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 10 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + C, t[i+6] = madd2(m, qElementWord6, t[i+6], C) + C, t[i+7] = madd2(m, qElementWord7, t[i+7], C) + C, t[i+8] = madd2(m, qElementWord8, t[i+8], C) + C, t[i+9] = madd2(m, qElementWord9, t[i+9], C) + C, t[i+10] = madd2(m, qElementWord10, t[i+10], C) + C, t[i+11] = madd2(m, qElementWord11, t[i+11], C) + + t[i+Limbs] += C + } + { + const i = 11 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + C, z[3] = madd2(m, qElementWord4, t[i+4], C) + C, z[4] = madd2(m, qElementWord5, t[i+5], C) + C, z[5] = madd2(m, qElementWord6, t[i+6], C) + C, z[6] = madd2(m, qElementWord7, t[i+7], C) + C, z[7] = madd2(m, qElementWord8, t[i+8], C) + C, z[8] = madd2(m, qElementWord9, t[i+9], C) + C, z[9] = madd2(m, qElementWord10, t[i+10], C) + z[11], z[10] = madd2(m, qElementWord11, t[i+11], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[11] < 81882988782276106 || (z[11] == 81882988782276106 && (z[10] < 15098257552581525310 || (z[10] == 15098257552581525310 && (z[9] < 13341377791855249032 || (z[9] == 13341377791855249032 && (z[8] < 5945444129596489281 || (z[8] == 5945444129596489281 && (z[7] < 8105254717682411801 || (z[7] == 8105254717682411801 && (z[6] < 274362232328168196 || (z[6] == 274362232328168196 && (z[5] < 9694500593442880912 || (z[5] == 9694500593442880912 && (z[4] < 8204665564953313070 || (z[4] == 8204665564953313070 && (z[3] < 10998096788944562424 || (z[3] == 10998096788944562424 && (z[2] < 1588918198704579639 || (z[2] == 1588918198704579639 && (z[1] < 16614129118623039618 || (z[1] == 16614129118623039618 && (z[0] < 17626244516597989515))))))))))))))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 17626244516597989515, 0) + z[1], b = bits.Sub64(z[1], 16614129118623039618, b) + z[2], b = bits.Sub64(z[2], 1588918198704579639, b) + z[3], b = bits.Sub64(z[3], 10998096788944562424, b) + z[4], b = bits.Sub64(z[4], 8204665564953313070, b) + z[5], b = bits.Sub64(z[5], 9694500593442880912, b) + z[6], b = bits.Sub64(z[6], 274362232328168196, b) + z[7], b = bits.Sub64(z[7], 8105254717682411801, b) + z[8], b = bits.Sub64(z[8], 5945444129596489281, b) + z[9], b = bits.Sub64(z[9], 13341377791855249032, b) + z[10], b = bits.Sub64(z[10], 15098257552581525310, b) + z[11], _ = bits.Sub64(z[11], 81882988782276106, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + z[5], b = bits.Sub64(z[5], 0, b) + z[6], b = bits.Sub64(z[6], 0, b) + z[7], b = bits.Sub64(z[7], 0, b) + z[8], b = bits.Sub64(z[8], 0, b) + z[9], b = bits.Sub64(z[9], 0, b) + z[10], b = bits.Sub64(z[10], 0, b) + z[11], b = bits.Sub64(z[11], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[11] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], b = bits.Add64(z[4], qElementWord4, b) + z[5], b = bits.Add64(z[5], qElementWord5, b) + z[6], b = bits.Add64(z[6], qElementWord6, b) + z[7], b = bits.Add64(z[7], qElementWord7, b) + z[8], b = bits.Add64(z[8], qElementWord8, b) + z[9], b = bits.Add64(z[9], qElementWord9, b) + z[10], b = bits.Add64(z[10], qElementWord10, b) + z[11], _ = bits.Add64(neg1, qElementWord11, b) } + } +} - // v >= u - bigger = !(v[11] < u[11] || (v[11] == u[11] && (v[10] < u[10] || (v[10] == u[10] && (v[9] < u[9] || (v[9] == u[9] && (v[8] < u[8] || (v[8] == u[8] && (v[7] < u[7] || (v[7] == u[7] && (v[6] < u[6] || (v[6] == u[6] && (v[5] < u[5] || (v[5] == u[5] && (v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))))))))))))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], borrow = bits.Sub64(v[4], u[4], borrow) - v[5], borrow = bits.Sub64(v[5], u[5], borrow) - v[6], borrow = bits.Sub64(v[6], u[6], borrow) - v[7], borrow = bits.Sub64(v[7], u[7], borrow) - v[8], borrow = bits.Sub64(v[8], u[8], borrow) - v[9], borrow = bits.Sub64(v[9], u[9], borrow) - v[10], borrow = bits.Sub64(v[10], u[10], borrow) - v[11], _ = bits.Sub64(v[11], u[11], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) - s[5], borrow = bits.Sub64(s[5], r[5], borrow) - s[6], borrow = bits.Sub64(s[6], r[6], borrow) - s[7], borrow = bits.Sub64(s[7], r[7], borrow) - s[8], borrow = bits.Sub64(s[8], r[8], borrow) - s[9], borrow = bits.Sub64(s[9], r[9], borrow) - s[10], borrow = bits.Sub64(s[10], r[10], borrow) - s[11], borrow = bits.Sub64(s[11], r[11], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 17626244516597989515, 0) - s[1], carry = bits.Add64(s[1], 16614129118623039618, carry) - s[2], carry = bits.Add64(s[2], 1588918198704579639, carry) - s[3], carry = bits.Add64(s[3], 10998096788944562424, carry) - s[4], carry = bits.Add64(s[4], 8204665564953313070, carry) - s[5], carry = bits.Add64(s[5], 9694500593442880912, carry) - s[6], carry = bits.Add64(s[6], 274362232328168196, carry) - s[7], carry = bits.Add64(s[7], 8105254717682411801, carry) - s[8], carry = bits.Add64(s[8], 5945444129596489281, carry) - s[9], carry = bits.Add64(s[9], 13341377791855249032, carry) - s[10], carry = bits.Add64(s[10], 15098257552581525310, carry) - s[11], _ = bits.Add64(s[11], 81882988782276106, carry) +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - } - } else { +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], borrow = bits.Sub64(u[4], v[4], borrow) - u[5], borrow = bits.Sub64(u[5], v[5], borrow) - u[6], borrow = bits.Sub64(u[6], v[6], borrow) - u[7], borrow = bits.Sub64(u[7], v[7], borrow) - u[8], borrow = bits.Sub64(u[8], v[8], borrow) - u[9], borrow = bits.Sub64(u[9], v[9], borrow) - u[10], borrow = bits.Sub64(u[10], v[10], borrow) - u[11], _ = bits.Sub64(u[11], v[11], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - r[5], borrow = bits.Sub64(r[5], s[5], borrow) - r[6], borrow = bits.Sub64(r[6], s[6], borrow) - r[7], borrow = bits.Sub64(r[7], s[7], borrow) - r[8], borrow = bits.Sub64(r[8], s[8], borrow) - r[9], borrow = bits.Sub64(r[9], s[9], borrow) - r[10], borrow = bits.Sub64(r[10], s[10], borrow) - r[11], borrow = bits.Sub64(r[11], s[11], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 17626244516597989515, 0) - r[1], carry = bits.Add64(r[1], 16614129118623039618, carry) - r[2], carry = bits.Add64(r[2], 1588918198704579639, carry) - r[3], carry = bits.Add64(r[3], 10998096788944562424, carry) - r[4], carry = bits.Add64(r[4], 8204665564953313070, carry) - r[5], carry = bits.Add64(r[5], 9694500593442880912, carry) - r[6], carry = bits.Add64(r[6], 274362232328168196, carry) - r[7], carry = bits.Add64(r[7], 8105254717682411801, carry) - r[8], carry = bits.Add64(r[8], 5945444129596489281, carry) - r[9], carry = bits.Add64(r[9], 13341377791855249032, carry) - r[10], carry = bits.Add64(r[10], 15098257552581525310, carry) - r[11], _ = bits.Add64(r[11], 81882988782276106, carry) + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + z[5], b = bits.Sub64(0, x[5], b) + z[6], b = bits.Sub64(0, x[6], b) + z[7], b = bits.Sub64(0, x[7], b) + z[8], b = bits.Sub64(0, x[8], b) + z[9], b = bits.Sub64(0, x[9], b) + z[10], b = bits.Sub64(0, x[10], b) + z[11], b = bits.Sub64(0, x[11], b) + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} - } +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) + c, z[5] = madd1(x[5], w, c) + c, z[6] = madd1(x[6], w, c) + c, z[7] = madd1(x[7], w, c) + c, z[8] = madd1(x[8], w, c) + c, z[9] = madd1(x[9], w, c) + c, z[10] = madd1(x[10], w, c) + c, z[11] = madd1(x[11], w, c) + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w + + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 + + h[0], s[0] = bits.Mul64(x[0], w) + + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 + + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 + + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 5 - 1 + + h[curI], s[curI] = bits.Mul64(x[5], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[11]|u[10]|u[9]|u[8]|u[7]|u[6]|u[5]|u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 6 % 2 + const prevI = 1 - curI + const iMinusOne = 6 - 1 + + h[curI], s[curI] = bits.Mul64(x[6], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 7 % 2 + const prevI = 1 - curI + const iMinusOne = 7 - 1 + + h[curI], s[curI] = bits.Mul64(x[7], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 8 % 2 + const prevI = 1 - curI + const iMinusOne = 8 - 1 + + h[curI], s[curI] = bits.Mul64(x[8], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[11]|v[10]|v[9]|v[8]|v[7]|v[6]|v[5]|v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 9 % 2 + const prevI = 1 - curI + const iMinusOne = 9 - 1 + + h[curI], s[curI] = bits.Mul64(x[9], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + + { + const curI = 10 % 2 + const prevI = 1 - curI + const iMinusOne = 10 - 1 + + h[curI], s[curI] = bits.Mul64(x[10], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + + { + const curI = 11 % 2 + const prevI = 1 - curI + const iMinusOne = 11 - 1 + + h[curI], s[curI] = bits.Mul64(x[11], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + { + const curI = 12 % 2 + const prevI = 1 - curI + const iMinusOne = 11 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + z[5], carry = bits.Add64(z[5], yTimes[5], carry) + z[6], carry = bits.Add64(z[6], yTimes[6], carry) + z[7], carry = bits.Add64(z[7], yTimes[7], carry) + z[8], carry = bits.Add64(z[8], yTimes[8], carry) + z[9], carry = bits.Add64(z[9], yTimes[9], carry) + z[10], carry = bits.Add64(z[10], yTimes[10], carry) + z[11], carry = bits.Add64(z[11], yTimes[11], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bw6-761/fp/element_test.go b/ecc/bw6-761/fp/element_test.go index 55503d2a8f..bbd68ce9d7 100644 --- a/ecc/bw6-761/fp/element_test.go +++ b/ecc/bw6-761/fp/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -287,7 +288,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -2074,3 +2075,516 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + inversionCorrectionFactorWord10, + inversionCorrectionFactorWord11, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + inversionCorrectionFactorWord6, + inversionCorrectionFactorWord7, + inversionCorrectionFactorWord8, + inversionCorrectionFactorWord9, + inversionCorrectionFactorWord10, + inversionCorrectionFactorWord11, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/ecc/bw6-761/fr/element.go b/ecc/bw6-761/fr/element.go index 00fd61eab8..ac2c7cec21 100644 --- a/ecc/bw6-761/fr/element.go +++ b/ecc/bw6-761/fr/element.go @@ -63,15 +63,25 @@ func Modulus() *big.Int { } // q (modulus) +const qElementWord0 uint64 = 9586122913090633729 +const qElementWord1 uint64 = 1660523435060625408 +const qElementWord2 uint64 = 2230234197602682880 +const qElementWord3 uint64 = 1883307231910630287 +const qElementWord4 uint64 = 14284016967150029115 +const qElementWord5 uint64 = 121098312706494698 + var qElement = Element{ - 9586122913090633729, - 1660523435060625408, - 2230234197602682880, - 1883307231910630287, - 14284016967150029115, - 121098312706494698, + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = 9586122913090633727 + // rSquare var rSquare = Element{ 13224372171368877346, @@ -197,7 +207,7 @@ func (z *Element) IsZero() bool { return (z[5] | z[4] | z[3] | z[2] | z[1] | z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *Element) IsUint64() bool { return (z[5] | z[4] | z[3] | z[2] | z[1]) == 0 } @@ -281,7 +291,7 @@ func (z *Element) SetRandom() (*Element, error) { z[5] = binary.BigEndian.Uint64(bytes[40:48]) z[5] %= 121098312706494698 - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -485,7 +495,90 @@ func _mulGeneric(z, x, y *Element) { z[5], z[4] = madd3(m, 121098312706494698, c[0], c[2], c[1]) } - // if z > q --> z -= q + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 9586122913090633729, 0) + z[1], b = bits.Sub64(z[1], 1660523435060625408, b) + z[2], b = bits.Sub64(z[2], 2230234197602682880, b) + z[3], b = bits.Sub64(z[3], 1883307231910630287, b) + z[4], b = bits.Sub64(z[4], 14284016967150029115, b) + z[5], _ = bits.Sub64(z[5], 121098312706494698, b) + } +} + +func _mulWGeneric(z, x *Element, y uint64) { + + var t [6]uint64 + { + // round 0 + c1, c0 := bits.Mul64(y, x[0]) + m := c0 * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, c0) + c1, c0 = madd1(y, x[1], c1) + c2, t[0] = madd2(m, 1660523435060625408, c2, c0) + c1, c0 = madd1(y, x[2], c1) + c2, t[1] = madd2(m, 2230234197602682880, c2, c0) + c1, c0 = madd1(y, x[3], c1) + c2, t[2] = madd2(m, 1883307231910630287, c2, c0) + c1, c0 = madd1(y, x[4], c1) + c2, t[3] = madd2(m, 14284016967150029115, c2, c0) + c1, c0 = madd1(y, x[5], c1) + t[5], t[4] = madd3(m, 121098312706494698, c0, c2, c1) + } + { + // round 1 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 2 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 3 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 4 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, t[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, t[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, t[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, t[3] = madd2(m, 14284016967150029115, c2, t[4]) + t[5], t[4] = madd2(m, 121098312706494698, t[5], c2) + } + { + // round 5 + m := t[0] * 9586122913090633727 + c2 := madd0(m, 9586122913090633729, t[0]) + c2, z[0] = madd2(m, 1660523435060625408, c2, t[1]) + c2, z[1] = madd2(m, 2230234197602682880, c2, t[2]) + c2, z[2] = madd2(m, 1883307231910630287, c2, t[3]) + c2, z[3] = madd2(m, 14284016967150029115, c2, t[4]) + z[5], z[4] = madd2(m, 121098312706494698, t[5], c2) + } + + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -568,7 +661,7 @@ func _fromMontGeneric(z *Element) { z[5] = C } - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -591,7 +684,7 @@ func _addGeneric(z, x, y *Element) { z[4], carry = bits.Add64(x[4], y[4], carry) z[5], _ = bits.Add64(x[5], y[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -614,7 +707,7 @@ func _doubleGeneric(z, x *Element) { z[4], carry = bits.Add64(x[4], x[4], carry) z[5], _ = bits.Add64(x[5], x[5], carry) - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -662,7 +755,7 @@ func _negGeneric(z, x *Element) { func _reduceGeneric(z *Element) { - // if z > q --> z -= q + // if z > q → z -= q // note: this is NOT constant time if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { var b uint64 @@ -778,7 +871,7 @@ func (z *Element) Exp(x Element, exponent *big.Int) *Element { } // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *Element) ToMont() *Element { return z.Mul(z, &rSquare) } @@ -912,7 +1005,7 @@ func (z *Element) SetBigInt(v *big.Int) *Element { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *Element) setBigInt(v *big.Int) *Element { vBits := v.Bits() @@ -1100,181 +1193,496 @@ func (z *Element) Sqrt(x *Element) *Element { } } -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 +const inversionCorrectionFactorWord0 = 16386826051656692015 +const inversionCorrectionFactorWord1 = 8373462824848618879 +const inversionCorrectionFactorWord2 = 7553521018781888459 +const inversionCorrectionFactorWord3 = 595240760696852504 +const inversionCorrectionFactorWord4 = 16794241053652767540 +const inversionCorrectionFactorWord5 = 43911691917702151 + +const invIterationsN = 26 + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *Element) Inverse(x *Element) *Element { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = Element{ - 9586122913090633729, - 1660523435060625408, - 2230234197602682880, - 1883307231910630287, - 14284016967150029115, - 121098312706494698, + a := *x + b := Element{ + qElementWord0, + qElementWord1, + qElementWord2, + qElementWord3, + qElementWord4, + qElementWord5, + } // b := q + + u := Element{1} + + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 + + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 + + var i uint + + var v, s Element + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still + } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ + } + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi&signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) + } + // right-shift a by k-1 bits + a[0] = (a[0] >> approxLowBitsN) | ((a[1]) << approxHighBitsN) + a[1] = (a[1] >> approxLowBitsN) | ((a[2]) << approxHighBitsN) + a[2] = (a[2] >> approxLowBitsN) | ((a[3]) << approxHighBitsN) + a[3] = (a[3] >> approxLowBitsN) | ((a[4]) << approxHighBitsN) + a[4] = (a[4] >> approxLowBitsN) | ((a[5]) << approxHighBitsN) + a[5] = (a[5] >> approxLowBitsN) | (aHi << approxHighBitsN) + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi&signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) + } + // right-shift b by k-1 bits + b[0] = (b[0] >> approxLowBitsN) | ((b[1]) << approxHighBitsN) + b[1] = (b[1] >> approxLowBitsN) | ((b[2]) << approxHighBitsN) + b[2] = (b[2] >> approxLowBitsN) | ((b[3]) << approxHighBitsN) + b[3] = (b[3] >> approxLowBitsN) | ((b[4]) << approxHighBitsN) + b[4] = (b[4] >> approxLowBitsN) | ((b[5]) << approxHighBitsN) + b[5] = (b[5] >> approxLowBitsN) | (bHi << approxHighBitsN) + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } } - // initialize s = r^2 - var s = Element{ - 13224372171368877346, - 227991066186625457, - 2496666625421784173, - 13825906835078366124, - 9475172226622360569, - 30958721782860680, + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } - // r = 0 - r := Element{} + z.Mul(&v, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + return z +} - v := *x +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *Element, nBits int) uint64 { - var carry, borrow uint64 - var bigger bool + if nBits <= 64 { + return x[0] + } - for { - for v[0]&1 == 0 { + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] - // v = v >> 1 + hiWordIndex := (nBits - 1) / 64 - v[0] = v[0]>>1 | v[1]<<63 - v[1] = v[1]>>1 | v[2]<<63 - v[2] = v[2]>>1 | v[3]<<63 - v[3] = v[3]>>1 | v[4]<<63 - v[4] = v[4]>>1 | v[5]<<63 - v[5] >>= 1 + hiWordBitsAvailable := nBits - hiWordIndex*64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) - if s[0]&1 == 1 { + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) - // s = s + q - s[0], carry = bits.Add64(s[0], 9586122913090633729, 0) - s[1], carry = bits.Add64(s[1], 1660523435060625408, carry) - s[2], carry = bits.Add64(s[2], 2230234197602682880, carry) - s[3], carry = bits.Add64(s[3], 1883307231910630287, carry) - s[4], carry = bits.Add64(s[4], 14284016967150029115, carry) - s[5], _ = bits.Add64(s[5], 121098312706494698, carry) + mask_ = ^(1<<(approxLowBitsN+hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed - } + return lo | mid | hi +} - // s = s >> 1 +func (z *Element) linearCombSosSigned(x *Element, xC int64, y *Element, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *Element) montReduceSigned(x *Element, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi&signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, qElementWord0, x[0]) + C, t[1] = madd2(m, qElementWord1, x[1], C) + C, t[2] = madd2(m, qElementWord2, x[2], C) + C, t[3] = madd2(m, qElementWord3, x[3], C) + C, t[4] = madd2(m, qElementWord4, x[4], C) + C, t[5] = madd2(m, qElementWord5, x[5], C) + + // the high word of m * qElement[5] is at most 62 bits + // x[5] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[6] = xHi + C + // xHi and C are 63 bits, therefore no overflow + + { + const i = 1 + m = t[i] * qInvNegLsw - s[0] = s[0]>>1 | s[1]<<63 - s[1] = s[1]>>1 | s[2]<<63 - s[2] = s[2]>>1 | s[3]<<63 - s[3] = s[3]>>1 | s[4]<<63 - s[4] = s[4]>>1 | s[5]<<63 - s[5] >>= 1 + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + t[i+Limbs] += C + } + { + const i = 2 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 3 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 4 + m = t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, t[i+1] = madd2(m, qElementWord1, t[i+1], C) + C, t[i+2] = madd2(m, qElementWord2, t[i+2], C) + C, t[i+3] = madd2(m, qElementWord3, t[i+3], C) + C, t[i+4] = madd2(m, qElementWord4, t[i+4], C) + C, t[i+5] = madd2(m, qElementWord5, t[i+5], C) + + t[i+Limbs] += C + } + { + const i = 5 + m := t[i] * qInvNegLsw + + C = madd0(m, qElementWord0, t[i+0]) + C, z[0] = madd2(m, qElementWord1, t[i+1], C) + C, z[1] = madd2(m, qElementWord2, t[i+2], C) + C, z[2] = madd2(m, qElementWord3, t[i+3], C) + C, z[3] = madd2(m, qElementWord4, t[i+4], C) + z[5], z[4] = madd2(m, qElementWord5, t[i+5], C) + } + + // if z > q → z -= q + // note: this is NOT constant time + if !(z[5] < 121098312706494698 || (z[5] == 121098312706494698 && (z[4] < 14284016967150029115 || (z[4] == 14284016967150029115 && (z[3] < 1883307231910630287 || (z[3] == 1883307231910630287 && (z[2] < 2230234197602682880 || (z[2] == 2230234197602682880 && (z[1] < 1660523435060625408 || (z[1] == 1660523435060625408 && (z[0] < 9586122913090633729))))))))))) { + var b uint64 + z[0], b = bits.Sub64(z[0], 9586122913090633729, 0) + z[1], b = bits.Sub64(z[1], 1660523435060625408, b) + z[2], b = bits.Sub64(z[2], 2230234197602682880, b) + z[3], b = bits.Sub64(z[3], 1883307231910630287, b) + z[4], b = bits.Sub64(z[4], 14284016967150029115, b) + z[5], _ = bits.Sub64(z[5], 121098312706494698, b) + } + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + z[1], b = bits.Sub64(z[1], 0, b) + z[2], b = bits.Sub64(z[2], 0, b) + z[3], b = bits.Sub64(z[3], 0, b) + z[4], b = bits.Sub64(z[4], 0, b) + z[5], b = bits.Sub64(z[5], 0, b) + + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[5] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + z[0], b = bits.Add64(z[0], qElementWord0, b) + z[1], b = bits.Add64(z[1], qElementWord1, b) + z[2], b = bits.Add64(z[2], qElementWord2, b) + z[3], b = bits.Add64(z[3], qElementWord3, b) + z[4], b = bits.Add64(z[4], qElementWord4, b) + z[5], _ = bits.Add64(neg1, qElementWord5, b) } - for u[0]&1 == 0 { + } +} - // u = u >> 1 +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *Element) mulWSigned(x *Element, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y^m)-m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} - u[0] = u[0]>>1 | u[1]<<63 - u[1] = u[1]>>1 | u[2]<<63 - u[2] = u[2]>>1 | u[3]<<63 - u[3] = u[3]>>1 | u[4]<<63 - u[4] = u[4]>>1 | u[5]<<63 - u[5] >>= 1 +func (z *Element) neg(x *Element, xHi uint64) uint64 { + var b uint64 - if r[0]&1 == 1 { + z[0], b = bits.Sub64(0, x[0], 0) + z[1], b = bits.Sub64(0, x[1], b) + z[2], b = bits.Sub64(0, x[2], b) + z[3], b = bits.Sub64(0, x[3], b) + z[4], b = bits.Sub64(0, x[4], b) + z[5], b = bits.Sub64(0, x[5], b) + xHi, _ = bits.Sub64(0, xHi, b) - // r = r + q - r[0], carry = bits.Add64(r[0], 9586122913090633729, 0) - r[1], carry = bits.Add64(r[1], 1660523435060625408, carry) - r[2], carry = bits.Add64(r[2], 2230234197602682880, carry) - r[3], carry = bits.Add64(r[3], 1883307231910630287, carry) - r[4], carry = bits.Add64(r[4], 14284016967150029115, carry) - r[5], _ = bits.Add64(r[5], 121098312706494698, carry) + return xHi +} - } +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *Element) mulWRegular(x *Element, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y ^ m) - m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + c, z[1] = madd1(x[1], w, c) + c, z[2] = madd1(x[2], w, c) + c, z[3] = madd1(x[3], w, c) + c, z[4] = madd1(x[4], w, c) + c, z[5] = madd1(x[5], w, c) + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *Element) mulWRegularBf(x *Element, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w - // r = r >> 1 + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 - r[0] = r[0]>>1 | r[1]<<63 - r[1] = r[1]>>1 | r[2]<<63 - r[2] = r[2]>>1 | r[3]<<63 - r[3] = r[3]>>1 | r[4]<<63 - r[4] = r[4]>>1 | r[5]<<63 - r[5] >>= 1 + h[0], s[0] = bits.Mul64(x[0], w) + c := uint64(0) + b := uint64(0) + + { + const curI = 1 % 2 + const prevI = 1 - curI + const iMinusOne = 1 - 1 + + h[curI], s[curI] = bits.Mul64(x[1], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - // v >= u - bigger = !(v[5] < u[5] || (v[5] == u[5] && (v[4] < u[4] || (v[4] == u[4] && (v[3] < u[3] || (v[3] == u[3] && (v[2] < u[2] || (v[2] == u[2] && (v[1] < u[1] || (v[1] == u[1] && (v[0] < u[0]))))))))))) - - if bigger { - - // v = v - u - v[0], borrow = bits.Sub64(v[0], u[0], 0) - v[1], borrow = bits.Sub64(v[1], u[1], borrow) - v[2], borrow = bits.Sub64(v[2], u[2], borrow) - v[3], borrow = bits.Sub64(v[3], u[3], borrow) - v[4], borrow = bits.Sub64(v[4], u[4], borrow) - v[5], _ = bits.Sub64(v[5], u[5], borrow) - - // s = s - r - s[0], borrow = bits.Sub64(s[0], r[0], 0) - s[1], borrow = bits.Sub64(s[1], r[1], borrow) - s[2], borrow = bits.Sub64(s[2], r[2], borrow) - s[3], borrow = bits.Sub64(s[3], r[3], borrow) - s[4], borrow = bits.Sub64(s[4], r[4], borrow) - s[5], borrow = bits.Sub64(s[5], r[5], borrow) - - if borrow == 1 { - - // s = s + q - s[0], carry = bits.Add64(s[0], 9586122913090633729, 0) - s[1], carry = bits.Add64(s[1], 1660523435060625408, carry) - s[2], carry = bits.Add64(s[2], 2230234197602682880, carry) - s[3], carry = bits.Add64(s[3], 1883307231910630287, carry) - s[4], carry = bits.Add64(s[4], 14284016967150029115, carry) - s[5], _ = bits.Add64(s[5], 121098312706494698, carry) + { + const curI = 2 % 2 + const prevI = 1 - curI + const iMinusOne = 2 - 1 - } - } else { + h[curI], s[curI] = bits.Mul64(x[2], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } - // u = u - v - u[0], borrow = bits.Sub64(u[0], v[0], 0) - u[1], borrow = bits.Sub64(u[1], v[1], borrow) - u[2], borrow = bits.Sub64(u[2], v[2], borrow) - u[3], borrow = bits.Sub64(u[3], v[3], borrow) - u[4], borrow = bits.Sub64(u[4], v[4], borrow) - u[5], _ = bits.Sub64(u[5], v[5], borrow) - - // r = r - s - r[0], borrow = bits.Sub64(r[0], s[0], 0) - r[1], borrow = bits.Sub64(r[1], s[1], borrow) - r[2], borrow = bits.Sub64(r[2], s[2], borrow) - r[3], borrow = bits.Sub64(r[3], s[3], borrow) - r[4], borrow = bits.Sub64(r[4], s[4], borrow) - r[5], borrow = bits.Sub64(r[5], s[5], borrow) - - if borrow == 1 { - - // r = r + q - r[0], carry = bits.Add64(r[0], 9586122913090633729, 0) - r[1], carry = bits.Add64(r[1], 1660523435060625408, carry) - r[2], carry = bits.Add64(r[2], 2230234197602682880, carry) - r[3], carry = bits.Add64(r[3], 1883307231910630287, carry) - r[4], carry = bits.Add64(r[4], 14284016967150029115, carry) - r[5], _ = bits.Add64(r[5], 121098312706494698, carry) + { + const curI = 3 % 2 + const prevI = 1 - curI + const iMinusOne = 3 - 1 - } + h[curI], s[curI] = bits.Mul64(x[3], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (u[0] == 1) && (u[5]|u[4]|u[3]|u[2]|u[1]) == 0 { - z.Set(&r) - return z + + { + const curI = 4 % 2 + const prevI = 1 - curI + const iMinusOne = 4 - 1 + + h[curI], s[curI] = bits.Mul64(x[4], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } - if (v[0] == 1) && (v[5]|v[4]|v[3]|v[2]|v[1]) == 0 { - z.Set(&s) - return z + + { + const curI = 5 % 2 + const prevI = 1 - curI + const iMinusOne = 5 - 1 + + h[curI], s[curI] = bits.Mul64(x[5], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] } + { + const curI = 6 % 2 + const prevI = 1 - curI + const iMinusOne = 5 + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c } +}*/ + +// Requires NoCarry +func (z *Element) linearCombNonModular(x *Element, xC int64, y *Element, yC int64) uint64 { + var yTimes Element + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + z[0], carry = bits.Add64(z[0], yTimes[0], carry) + z[1], carry = bits.Add64(z[1], yTimes[1], carry) + z[2], carry = bits.Add64(z[2], yTimes[2], carry) + z[3], carry = bits.Add64(z[3], yTimes[3], carry) + z[4], carry = bits.Add64(z[4], yTimes[4], carry) + z[5], carry = bits.Add64(z[5], yTimes[5], carry) + + yHi, _ = bits.Add64(xHi, yHi, carry) + return yHi } diff --git a/ecc/bw6-761/fr/element_test.go b/ecc/bw6-761/fr/element_test.go index 5cd2173564..ebadf1fb00 100644 --- a/ecc/bw6-761/fr/element_test.go +++ b/ecc/bw6-761/fr/element_test.go @@ -22,6 +22,7 @@ import ( "fmt" "math/big" "math/bits" + mrand "math/rand" "testing" "github.com/leanovate/gopter" @@ -275,7 +276,7 @@ var staticTestValues []Element func init() { staticTestValues = append(staticTestValues, Element{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one Element one.SetOne() e.Sub(&qElement, &one) @@ -1990,3 +1991,504 @@ func genFull() gopter.Gen { return genResult } } + +func TestElementInversionApproximation(t *testing.T) { + var x Element + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs-j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func TestElementInversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func TestElementLinearComb(t *testing.T) { + var x Element + var y Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func TestElementInversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one Element + var oneInv Element + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x Element + var xInv Element + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac Element + fac.setBigInt(&i) // back to montgomery + + var facTimesFac Element + facTimesFac.Mul(&fac, &Element{ + inversionCorrectionFactorWord0, + inversionCorrectionFactorWord1, + inversionCorrectionFactorWord2, + inversionCorrectionFactorWord3, + inversionCorrectionFactorWord4, + inversionCorrectionFactorWord5, + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func TestElementBigNumNeg(t *testing.T) { + var a Element + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func TestElementBigNumWMul(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func TestElementVeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x Element + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func TestElementMontReducePos(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func TestElementMontReduceNeg(t *testing.T) { + var x Element + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64()|signBitSelector) + } +} + +func TestElementMontNegMultipleOfR(t *testing.T) { + var zero Element + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64()|signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *Element, xC int64, y *Element, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z Element + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + +func testBigNumWMul(t *testing.T, a *Element, c int64) { + var aHi uint64 + var aTimes Element + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *Element, xHi uint64) { + var res Element + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int + +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs*64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *Element) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *Element) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *Element, c int64, result *Element, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *Element) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *Element) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} diff --git a/field/field.go b/field/field.go index 39c3675f44..7597a663d7 100644 --- a/field/field.go +++ b/field/field.go @@ -29,44 +29,47 @@ var ( // Field precomputed values used in template for code generation of field element APIs type Field struct { - PackageName string - ElementName string - ModulusBig *big.Int - Modulus string - ModulusHex string - NbWords int - NbBits int - NbWordsLastIndex int - NbWordsIndexesNoZero []int - NbWordsIndexesFull []int - Q []uint64 - QInverse []uint64 - QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 - ASM bool - RSquare []uint64 - One []uint64 - LegendreExponent string // big.Int to base16 string - NoCarry bool - NoCarrySquare bool // used if NoCarry is set, but some op may overflow in square optimization - SqrtQ3Mod4 bool - SqrtAtkin bool - SqrtTonelliShanks bool - SqrtE uint64 - SqrtS []uint64 - SqrtAtkinExponent string // big.Int to base16 string - SqrtSMinusOneOver2 string // big.Int to base16 string - SqrtQ3Mod4Exponent string // big.Int to base16 string - SqrtG []uint64 // NonResidue ^ SqrtR (montgomery form) - NonResidue []uint64 // (montgomery form) - - LegendreExponentData *addchain.AddChainData - SqrtAtkinExponentData *addchain.AddChainData - SqrtSMinusOneOver2Data *addchain.AddChainData - SqrtQ3Mod4ExponentData *addchain.AddChainData - UseAddChain bool + PackageName string + ElementName string + ModulusBig *big.Int + Modulus string + ModulusHex string + NbWords int + NbBits int + NbWordsLastIndex int + NbWordsIndexesNoZero []int + NbWordsIndexesFull []int + NbWordsIndexesNoLast []int + NbWordsIndexesNoZeroNoLast []int + P20InversionCorrectiveFac []uint64 + P20InversionNbIterations int + Q []uint64 + QInverse []uint64 + QMinusOneHalvedP []uint64 // ((q-1) / 2 ) + 1 + ASM bool + RSquare []uint64 + One []uint64 + LegendreExponent string // big.Int to base16 string + NoCarry bool + NoCarrySquare bool // used if NoCarry is set, but some op may overflow in square optimization + SqrtQ3Mod4 bool + SqrtAtkin bool + SqrtTonelliShanks bool + SqrtE uint64 + SqrtS []uint64 + SqrtAtkinExponent string // big.Int to base16 string + SqrtSMinusOneOver2 string // big.Int to base16 string + SqrtQ3Mod4Exponent string // big.Int to base16 string + SqrtG []uint64 // NonResidue ^ SqrtR (montgomery form) + NonResidue []uint64 // (montgomery form) + LegendreExponentData *addchain.AddChainData + SqrtAtkinExponentData *addchain.AddChainData + SqrtSMinusOneOver2Data *addchain.AddChainData + SqrtQ3Mod4ExponentData *addchain.AddChainData + UseAddChain bool } -// NewField returns a data structure with needed informations to generate apis for field element +// NewField returns a data structure with needed information to generate apis for field element // // See field/generator package func NewField(packageName, elementName, modulus string, useAddChain bool) (*Field, error) { @@ -110,7 +113,22 @@ func NewField(packageName, elementName, modulus string, useAddChain bool) (*Fiel _qInv.Mod(_qInv, _r) F.QInverse = toUint64Slice(_qInv, F.NbWords) - // rsquare + // Pornin20 inversion correction factors + k := 32 // Optimized for 64 bit machines, still works for 32 + + p20InvInnerLoopNbIterations := 2*F.NbBits - 1 + // if constant time inversion then p20InvInnerLoopNbIterations-- (among other changes) + F.P20InversionNbIterations = (p20InvInnerLoopNbIterations-1)/(k-1) + 1 // ⌈ (2 * field size - 1) / (k-1) ⌉ + F.P20InversionNbIterations += F.P20InversionNbIterations % 2 // "round up" to a multiple of 2 + + kLimbs := k * F.NbWords + p20InversionCorrectiveFacPower := kLimbs*6 + F.P20InversionNbIterations*(kLimbs-k+1) + p20InversionCorrectiveFac := big.NewInt(1) + p20InversionCorrectiveFac.Lsh(p20InversionCorrectiveFac, uint(p20InversionCorrectiveFacPower)) + p20InversionCorrectiveFac.Mod(p20InversionCorrectiveFac, &bModulus) + F.P20InversionCorrectiveFac = toUint64Slice(p20InversionCorrectiveFac, F.NbWords) + + // rsquare _rSquare := big.NewInt(2) exponent := big.NewInt(int64(F.NbWords) * 64 * 2) _rSquare.Exp(_rSquare, exponent, &bModulus) @@ -124,11 +142,19 @@ func NewField(packageName, elementName, modulus string, useAddChain bool) (*Fiel // indexes (template helpers) F.NbWordsIndexesFull = make([]int, F.NbWords) F.NbWordsIndexesNoZero = make([]int, F.NbWords-1) + F.NbWordsIndexesNoLast = make([]int, F.NbWords-1) + F.NbWordsIndexesNoZeroNoLast = make([]int, F.NbWords-2) for i := 0; i < F.NbWords; i++ { F.NbWordsIndexesFull[i] = i if i > 0 { F.NbWordsIndexesNoZero[i-1] = i } + if i != F.NbWords-1 { + F.NbWordsIndexesNoLast[i] = i + if i > 0 { + F.NbWordsIndexesNoZeroNoLast[i-1] = i + } + } } // See https://hackmd.io/@zkteam/modular_multiplication @@ -136,7 +162,7 @@ func NewField(packageName, elementName, modulus string, useAddChain bool) (*Fiel // we can simplify the montgomery multiplication const B = (^uint64(0) >> 1) - 1 F.NoCarry = (F.Q[len(F.Q)-1] <= B) && F.NbWords <= 12 - const BSquare = (^uint64(0) >> 2) + const BSquare = ^uint64(0) >> 2 F.NoCarrySquare = F.Q[len(F.Q)-1] <= BSquare // Legendre exponent (p-1)/2 @@ -184,7 +210,7 @@ func NewField(packageName, elementName, modulus string, useAddChain bool) (*Fiel // use Tonelli-Shanks F.SqrtTonelliShanks = true - // Write q-1 =2^e * s , s odd + // Write q-1 =2ᵉ * s , s odd var s big.Int one.SetUint64(1) s.Sub(&bModulus, &one) diff --git a/field/generator/generator.go b/field/generator/generator.go index 13e7063706..56d2075bcc 100644 --- a/field/generator/generator.go +++ b/field/generator/generator.go @@ -36,6 +36,7 @@ func GenerateFF(F *field.Field, outputDir string) error { element.MulNoCarry, element.Sqrt, element.Inverse, + element.BigNum, } // test file templates @@ -44,8 +45,8 @@ func GenerateFF(F *field.Field, outputDir string) error { element.MulNoCarry, element.Reduce, element.Test, + element.InverseTests, } - // output files eName := strings.ToLower(F.ElementName) diff --git a/field/internal/templates/element/base.go b/field/internal/templates/element/base.go index d8ef75d0cf..f2f8d7b0a6 100644 --- a/field/internal/templates/element/base.go +++ b/field/internal/templates/element/base.go @@ -49,11 +49,18 @@ func Modulus() *big.Int { } // q (modulus) +{{- range $i := $.NbWordsIndexesFull}} +const q{{$.ElementName}}Word{{$i}} uint64 = {{index $.Q $i}} +{{- end}} + var q{{.ElementName}} = {{.ElementName}}{ - {{- range $i := .NbWordsIndexesFull}} - {{index $.Q $i}},{{end}} + {{- range $i := $.NbWordsIndexesFull}} + q{{$.ElementName}}Word{{$i}},{{end}} } +// Used for Montgomery reduction. (qInvNeg) q + r'.r = 1, i.e., qInvNeg = - q⁻¹ mod r +const qInvNegLsw uint64 = {{index .QInverse 0}} + // rSquare var rSquare = {{.ElementName}}{ {{- range $i := .RSquare}} @@ -169,7 +176,7 @@ func (z *{{.ElementName}}) IsZero() bool { return ( {{- range $i := reverse .NbWordsIndexesNoZero}} z[{{$i}}] | {{end}}z[0]) == 0 } -// IsUint64 returns true if z[0] >= 0 and all other words are 0 +// IsUint64 returns true if z[0] ⩾ 0 and all other words are 0 func (z *{{.ElementName}}) IsUint64() bool { return ( {{- range $i := reverse .NbWordsIndexesNoZero}} z[{{$i}}] {{- if ne $i 1}}|{{- end}} {{end}}) == 0 } @@ -319,6 +326,11 @@ func _mulGeneric(z,x,y *{{.ElementName}}) { {{ template "reduce" . }} } +func _mulWGeneric(z,x *{{.ElementName}}, y uint64) { + {{ template "mul_nocarry_v2" dict "all" . "V2" "x"}} + {{ template "reduce" . }} +} + func _fromMontGeneric(z *{{.ElementName}}) { // the following lines implement z = z * 1 @@ -513,7 +525,28 @@ func (z *{{.ElementName}}) BitLen() int { return bits.Len64(z[0]) } - +{{ define "add_q" }} + // {{$.V1}} = {{$.V1}} + q + {{$.V1}}[0], carry = bits.Add64({{$.V1}}[0], {{index $.all.Q 0}}, 0) + {{- range $i := .all.NbWordsIndexesNoZero}} + {{- if eq $i $.all.NbWordsLastIndex}} + {{$.V1}}[{{$i}}], _ = bits.Add64({{$.V1}}[{{$i}}], {{index $.all.Q $i}}, carry) + {{- else}} + {{$.V1}}[{{$i}}], carry = bits.Add64({{$.V1}}[{{$i}}], {{index $.all.Q $i}}, carry) + {{- end}} + {{- end}} +{{ end }} + +{{ define "rsh V nbWords" }} + // {{$.V}} = {{$.V}} >> 1 + {{$lastIndex := sub .nbWords 1}} + {{- range $i := iterate .nbWords}} + {{- if ne $i $lastIndex}} + {{$.V}}[{{$i}}] = {{$.V}}[{{$i}}] >> 1 | {{$.V}}[{{(add $i 1)}}] << 63 + {{- end}} + {{- end}} + {{$.V}}[{{$lastIndex}}] >>= 1 +{{ end }} ` diff --git a/field/internal/templates/element/bignum.go b/field/internal/templates/element/bignum.go new file mode 100644 index 0000000000..e6d02d263a --- /dev/null +++ b/field/internal/templates/element/bignum.go @@ -0,0 +1,103 @@ +package element + +const BigNum = ` + +{{/* Only used for the Pornin Extended GCD Inverse Algorithm*/}} +{{if eq .NoCarry true}} + +func (z *{{.ElementName}}) neg(x *{{.ElementName}}, xHi uint64) uint64 { + var b uint64 + + z[0], b = bits.Sub64(0, x[0], 0) + {{- range $i := .NbWordsIndexesNoZero}} + z[{{$i}}], b = bits.Sub64(0, x[{{$i}}], b) + {{- end}} + xHi, _ = bits.Sub64(0, xHi, b) + + return xHi +} + +// regular multiplication by one word regular (non montgomery) +// Fewer additions than the branch-free for positive y. Could be faster on some architectures +func (z *{{.ElementName}}) mulWRegular(x *{{.ElementName}}, y int64) uint64 { + + // w := abs(y) + m := y >> 63 + w := uint64((y^m)-m) + + var c uint64 + c, z[0] = bits.Mul64(x[0], w) + {{- range $i := .NbWordsIndexesNoZero }} + c, z[{{$i}}] = madd1(x[{{$i}}], w, c) + {{- end}} + + if y < 0 { + c = z.neg(z, c) + } + + return c +} + +/* +Removed: seems slower +// mulWRegular branch-free regular multiplication by one word (non montgomery) +func (z *{{.ElementName}}) mulWRegularBf(x *{{.ElementName}}, y int64) uint64 { + + w := uint64(y) + allNeg := uint64(y >> 63) // -1 if y < 0, 0 o.w + + // s[0], s[1] so results are not stored immediately in z. + // x[i] will be needed in the i+1 th iteration. We don't want to overwrite it in case x = z + var s [2]uint64 + var h [2]uint64 + + h[0], s[0] = bits.Mul64(x[0], w) + + c := uint64(0) + b := uint64(0) + + {{- range $i := .NbWordsIndexesNoZero}} + + { + const curI = {{$i}} % 2 + const prevI = 1 - curI + const iMinusOne = {{$i}} - 1 + + h[curI], s[curI] = bits.Mul64(x[{{$i}}], w) + s[curI], c = bits.Add64(s[curI], h[prevI], c) + s[curI], b = bits.Sub64(s[curI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + } + {{- end}} + { + const curI = {{.NbWords}} % 2 + const prevI = 1 - curI + const iMinusOne = {{.NbWordsLastIndex}} + + s[curI], _ = bits.Sub64(h[prevI], allNeg & x[iMinusOne], b) + z[iMinusOne] = s[prevI] + + return s[curI] + c + } +}*/ + +// Requires NoCarry +func (z *{{.ElementName}}) linearCombNonModular(x *{{.ElementName}}, xC int64, y *{{.ElementName}}, yC int64) uint64 { + var yTimes {{.ElementName}} + + yHi := yTimes.mulWRegular(y, yC) + xHi := z.mulWRegular(x, xC) + + carry := uint64(0) + + {{- range $i := .NbWordsIndexesFull}} + z[{{$i}}], carry = bits.Add64(z[{{$i}}], yTimes[{{$i}}], carry) + {{- end}} + + yHi, _ = bits.Add64(xHi, yHi, carry) + + return yHi +} + +{{- end}} +` diff --git a/field/internal/templates/element/conv.go b/field/internal/templates/element/conv.go index 2c21905236..5c606e3223 100644 --- a/field/internal/templates/element/conv.go +++ b/field/internal/templates/element/conv.go @@ -3,7 +3,7 @@ package element const Conv = ` // ToMont converts z to Montgomery form -// sets and returns z = z * r^2 +// sets and returns z = z * r² func (z *{{.ElementName}}) ToMont() *{{.ElementName}} { return z.Mul(z, &rSquare) } @@ -70,7 +70,6 @@ func (z {{.ElementName}}) ToBigIntRegular(res *big.Int) *big.Int { return z.ToBigInt(res) } - // Bytes returns the regular (non montgomery) value // of z as a big-endian byte array. func (z *{{.ElementName}}) Bytes() (res [Limbs*8]byte) { @@ -141,7 +140,7 @@ func (z *{{.ElementName}}) SetBigInt(v *big.Int) *{{.ElementName}} { return z } -// setBigInt assumes 0 <= v < q +// setBigInt assumes 0 ⩽ v < q func (z *{{.ElementName}}) setBigInt(v *big.Int) *{{.ElementName}} { vBits := v.Bits() diff --git a/field/internal/templates/element/inverse.go b/field/internal/templates/element/inverse.go index 6341d54542..938c110662 100644 --- a/field/internal/templates/element/inverse.go +++ b/field/internal/templates/element/inverse.go @@ -5,7 +5,7 @@ const Inverse = ` {{/* We use big.Int for Inverse for these type of moduli */}} {{if eq .NoCarry false}} -// Inverse z = x^-1 mod q +// Inverse z = x⁻¹ mod q // note: allocates a big.Int (math/big) func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { var _xNonMont big.Int @@ -15,124 +15,301 @@ func (z *{{.ElementName}}) Inverse( x *{{.ElementName}}) *{{.ElementName}} { return z } - {{ else }} -// Inverse z = x^-1 mod q -// Algorithm 16 in "Efficient Software-Implementation of Finite Fields with Applications to Cryptography" -// if x == 0, sets and returns z = x +func max(a int, b int) int { + if a > b { + return a + } + return b +} + +func min(a int, b int) int { + if a < b { + return a + } + return b +} + +const updateFactorsConversionBias int64 = 0x7fffffff7fffffff // (2³¹ - 1)(2³² + 1) +const updateFactorIdentityMatrixRow0 = 1 +const updateFactorIdentityMatrixRow1 = 1 << 32 + +func updateFactorsDecompose(c int64) (int64, int64) { + c += updateFactorsConversionBias + const low32BitsFilter int64 = 0xFFFFFFFF + f := c&low32BitsFilter - 0x7FFFFFFF + g := c>>32&low32BitsFilter - 0x7FFFFFFF + return f, g +} + +const k = 32 // word size / 2 +const signBitSelector = uint64(1) << 63 +const approxLowBitsN = k - 1 +const approxHighBitsN = k + 1 + +{{- range $i := .NbWordsIndexesFull}} +const inversionCorrectionFactorWord{{$i}} = {{index $.P20InversionCorrectiveFac $i}} +{{- end}} + +const invIterationsN = {{.P20InversionNbIterations}} + +// Inverse z = x⁻¹ mod q +// Implements "Optimized Binary GCD for Modular Inversion" +// https://github.com/pornin/bingcd/blob/main/doc/bingcd.pdf func (z *{{.ElementName}}) Inverse(x *{{.ElementName}}) *{{.ElementName}} { if x.IsZero() { z.SetZero() return z } - // initialize u = q - var u = {{.ElementName}}{ + a := *x + b := {{.ElementName}} { {{- range $i := .NbWordsIndexesFull}} - {{index $.Q $i}},{{end}} - } + q{{$.ElementName}}Word{{$i}},{{end}} + } // b := q - // initialize s = r^2 - var s = {{.ElementName}}{ - {{- range $i := .RSquare}} - {{$i}},{{end}} - } + u := {{.ElementName}}{1} - // r = 0 - r := {{.ElementName}}{} + // Update factors: we get [u; v]:= [f0 g0; f1 g1] [u; v] + // c_i = f_i + 2³¹ - 1 + 2³² * (g_i + 2³¹ - 1) + var c0, c1 int64 - v := *x + // Saved update factors to reduce the number of field multiplications + var pf0, pf1, pg0, pg1 int64 - var carry, borrow uint64 - var bigger bool + var i uint - for { - for v[0]&1 == 0 { - {{ rsh "v" .NbWords}} - if s[0]&1 == 1 { - {{ template "add_q" dict "all" . "V1" "s" }} - } - {{ rsh "s" .NbWords}} - } - for u[0]&1 == 0 { - {{ rsh "u" .NbWords}} - if r[0]&1 == 1 { - {{ template "add_q" dict "all" . "V1" "r" }} - } - {{ rsh "r" .NbWords}} - } - {{ template "bigger" dict "all" . "V1" "v" "V2" "u"}} - if bigger { - {{ template "sub_noborrow" dict "all" . "V1" "v" "V2" "u" "OmitLast" "true"}} - {{ template "sub_noborrow" dict "all" . "V1" "s" "V2" "r" "OmitLast" "false"}} - if borrow == 1 { - {{ template "add_q" dict "all" . "V1" "s" }} - } - } else { - {{ template "sub_noborrow" dict "all" . "V1" "u" "V2" "v" "OmitLast" "true"}} - {{ template "sub_noborrow" dict "all" . "V1" "r" "V2" "s" "OmitLast" "false"}} - if borrow == 1 { - {{ template "add_q" dict "all" . "V1" "r" }} + var v, s {{.ElementName}} + + // Since u,v are updated every other iteration, we must make sure we terminate after evenly many iterations + // This also lets us get away with half as many updates to u,v + // To make this constant-time-ish, replace the condition with i < invIterationsN + for i = 0; i&1 == 1 || !a.IsZero(); i++ { + n := max(a.BitLen(), b.BitLen()) + aApprox, bApprox := approximate(&a, n), approximate(&b, n) + + // After 0 iterations, we have f₀ ≤ 2⁰ and f₁ < 2⁰ + // f0, g0, f1, g1 = 1, 0, 0, 1 + c0, c1 = updateFactorIdentityMatrixRow0, updateFactorIdentityMatrixRow1 + + for j := 0; j < approxLowBitsN; j++ { + + if aApprox&1 == 0 { + aApprox /= 2 + } else { + s, borrow := bits.Sub64(aApprox, bApprox, 0) + if borrow == 1 { + s = bApprox - aApprox + bApprox = aApprox + c0, c1 = c1, c0 + } + + aApprox = s / 2 + c0 = c0 - c1 + + // Now |f₀| < 2ʲ + 2ʲ = 2ʲ⁺¹ + // |f₁| ≤ 2ʲ still } + + c1 *= 2 + // |f₁| ≤ 2ʲ⁺¹ } - if (u[0] == 1) && ({{- range $i := reverse .NbWordsIndexesNoZero}}u[{{$i}}] {{if eq $i 1}}{{else}} | {{end}}{{end}} ) == 0 { - z.Set(&r) - return z + + s = a + + var g0 int64 + // from this point on c0 aliases for f0 + c0, g0 = updateFactorsDecompose(c0) + aHi := a.linearCombNonModular(&s, c0, &b, g0) + if aHi & signBitSelector != 0 { + // if aHi < 0 + c0, g0 = -c0, -g0 + aHi = a.neg(&a, aHi) } - if (v[0] == 1) && ({{- range $i := reverse .NbWordsIndexesNoZero}}v[{{$i}}] {{if eq $i 1}}{{else}} | {{end}}{{end}} ) == 0 { - z.Set(&s) - return z + // right-shift a by k-1 bits + + {{- range $i := .NbWordsIndexesFull}} + {{- if eq $i $.NbWordsLastIndex}} + a[{{$i}}] = (a[{{$i}}] >> approxLowBitsN) | (aHi << approxHighBitsN) + {{- else }} + a[{{$i}}] = (a[{{$i}}] >> approxLowBitsN) | ((a[{{add $i 1}}]) << approxHighBitsN) + {{- end}} + {{- end}} + + var f1 int64 + // from this point on c1 aliases for g0 + f1, c1 = updateFactorsDecompose(c1) + bHi := b.linearCombNonModular(&s, f1, &b, c1) + if bHi & signBitSelector != 0 { + // if bHi < 0 + f1, c1 = -f1, -c1 + bHi = b.neg(&b, bHi) } + // right-shift b by k-1 bits + + {{- range $i := .NbWordsIndexesFull}} + {{- if eq $i $.NbWordsLastIndex}} + b[{{$i}}] = (b[{{$i}}] >> approxLowBitsN) | (bHi << approxHighBitsN) + {{- else }} + b[{{$i}}] = (b[{{$i}}] >> approxLowBitsN) | ((b[{{add $i 1}}]) << approxHighBitsN) + {{- end}} + {{- end}} + + if i&1 == 1 { + // Combine current update factors with previously stored ones + // [f₀, g₀; f₁, g₁] ← [f₀, g₀; f₁, g₀] [pf₀, pg₀; pf₀, pg₀] + // We have |f₀|, |g₀|, |pf₀|, |pf₁| ≤ 2ᵏ⁻¹, and that |pf_i| < 2ᵏ⁻¹ for i ∈ {0, 1} + // Then for the new value we get |f₀| < 2ᵏ⁻¹ × 2ᵏ⁻¹ + 2ᵏ⁻¹ × 2ᵏ⁻¹ = 2²ᵏ⁻¹ + // Which leaves us with an extra bit for the sign + + // c0 aliases f0, c1 aliases g1 + c0, g0, f1, c1 = c0*pf0+g0*pf1, + c0*pg0+g0*pg1, + f1*pf0+c1*pf1, + f1*pg0+c1*pg1 + + s = u + u.linearCombSosSigned(&u, c0, &v, g0) + v.linearCombSosSigned(&s, f1, &v, c1) + + } else { + // Save update factors + pf0, pg0, pf1, pg1 = c0, g0, f1, c1 + } + } + + // For every iteration that we miss, v is not being multiplied by 2²ᵏ⁻² + const pSq int64 = 1 << (2 * (k - 1)) + // If the function is constant-time ish, this loop will not run (probably no need to take it out explicitly) + for ; i < invIterationsN; i += 2 { + v.mulWSigned(&v, pSq) } + z.Mul(&v, &{{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull }} + inversionCorrectionFactorWord{{$i}}, + {{- end}} + }) + return z } -{{ end }} +// approximate a big number x into a single 64 bit word using its uppermost and lowermost bits +// if x fits in a word as is, no approximation necessary +func approximate(x *{{.ElementName}}, nBits int) uint64 { + if nBits <= 64 { + return x[0] + } + const mask = (uint64(1) << (k - 1)) - 1 // k-1 ones + lo := mask & x[0] + hiWordIndex := (nBits - 1) / 64 -{{ define "bigger" }} - // {{$.V1}} >= {{$.V2}} - bigger = !({{- range $i := reverse $.all.NbWordsIndexesNoZero}} {{$.V1}}[{{$i}}] < {{$.V2}}[{{$i}}] || ( {{$.V1}}[{{$i}}] == {{$.V2}}[{{$i}}] && ( - {{- end}}{{$.V1}}[0] < {{$.V2}}[0] {{- range $i := $.all.NbWordsIndexesNoZero}} )) {{- end}} ) -{{ end }} + hiWordBitsAvailable := nBits - hiWordIndex * 64 + hiWordBitsUsed := min(hiWordBitsAvailable, approxHighBitsN) -{{ define "add_q" }} - // {{$.V1}} = {{$.V1}} + q - {{$.V1}}[0], carry = bits.Add64({{$.V1}}[0], {{index $.all.Q 0}}, 0) - {{- range $i := .all.NbWordsIndexesNoZero}} - {{- if eq $i $.all.NbWordsLastIndex}} - {{$.V1}}[{{$i}}], _ = bits.Add64({{$.V1}}[{{$i}}], {{index $.all.Q $i}}, carry) - {{- else}} - {{$.V1}}[{{$i}}], carry = bits.Add64({{$.V1}}[{{$i}}], {{index $.all.Q $i}}, carry) - {{- end}} + mask_ := uint64(^((1 << (hiWordBitsAvailable - hiWordBitsUsed)) - 1)) + hi := (x[hiWordIndex] & mask_) << (64 - hiWordBitsAvailable) + + mask_ = ^(1<<(approxLowBitsN + hiWordBitsUsed) - 1) + mid := (mask_ & x[hiWordIndex-1]) >> hiWordBitsUsed + + return lo | mid | hi +} + +func (z *{{.ElementName}}) linearCombSosSigned(x *{{.ElementName}}, xC int64, y *{{.ElementName}}, yC int64) { + hi := z.linearCombNonModular(x, xC, y, yC) + z.montReduceSigned(z, hi) +} + +// montReduceSigned SOS algorithm; xHi must be at most 63 bits long. Last bit of xHi may be used as a sign bit +func (z *{{.ElementName}}) montReduceSigned(x *{{.ElementName}}, xHi uint64) { + + const signBitRemover = ^signBitSelector + neg := xHi & signBitSelector != 0 + // the SOS implementation requires that most significant bit is 0 + // Let X be xHi*r + x + // note that if X is negative we would have initially stored it as 2⁶⁴ r + X + xHi &= signBitRemover + // with this a negative X is now represented as 2⁶³ r + X + + var t [2*Limbs - 1]uint64 + var C uint64 + + m := x[0] * qInvNegLsw + + C = madd0(m, q{{.ElementName}}Word0, x[0]) + {{- range $i := .NbWordsIndexesNoZero}} + C, t[{{$i}}] = madd2(m, q{{$.ElementName}}Word{{$i}}, x[{{$i}}], C) {{- end}} -{{ end }} -{{ define "sub_noborrow" }} - // {{$.V1}} = {{$.V1}} - {{$.V2}} - {{$.V1}}[0], borrow = bits.Sub64({{$.V1}}[0], {{$.V2}}[0], 0) - {{- range $i := .all.NbWordsIndexesNoZero}} - {{- if and (eq $i $.all.NbWordsLastIndex) (eq "true" $.OmitLast)}} - {{$.V1}}[{{$i}}], _ = bits.Sub64({{$.V1}}[{{$i}}], {{$.V2}}[{{$i}}], borrow) - {{- else}} - {{$.V1}}[{{$i}}], borrow = bits.Sub64({{$.V1}}[{{$i}}], {{$.V2}}[{{$i}}], borrow) + // the high word of m * q{{.ElementName}}[{{.NbWordsLastIndex}}] is at most 62 bits + // x[{{.NbWordsLastIndex}}] + C is at most 65 bits (high word at most 1 bit) + // Thus the resulting C will be at most 63 bits + t[{{.NbWords}}] = xHi + C + // xHi and C are 63 bits, therefore no overflow + + {{/* $NbWordsIndexesNoZeroInnerLoop := .NbWordsIndexesNoZero*/}} + {{- range $i := .NbWordsIndexesNoZeroNoLast}} + { + const i = {{$i}} + m = t[i] * qInvNegLsw + + C = madd0(m, q{{$.ElementName}}Word0, t[i+0]) + + {{- range $j := $.NbWordsIndexesNoZero}} + C, t[i + {{$j}}] = madd2(m, q{{$.ElementName}}Word{{$j}}, t[i + {{$j}}], C) {{- end}} + + t[i + Limbs] += C + } {{- end}} -{{ end }} + { + const i = {{.NbWordsLastIndex}} + m := t[i] * qInvNegLsw + C = madd0(m, q{{.ElementName}}Word0, t[i+0]) + {{- range $j := $.NbWordsIndexesNoZeroNoLast}} + C, z[{{sub $j 1}}] = madd2(m, q{{$.ElementName}}Word{{$j}}, t[i+{{$j}}], C) + {{- end}} + z[{{.NbWordsLastIndex}}], z[{{sub .NbWordsLastIndex 1}}] = madd2(m, q{{.ElementName}}Word{{.NbWordsLastIndex}}, t[i+{{.NbWordsLastIndex}}], C) + } -{{ define "rsh V nbWords" }} - // {{$.V}} = {{$.V}} >> 1 - {{$lastIndex := sub .nbWords 1}} - {{- range $i := iterate .nbWords}} - {{- if ne $i $lastIndex}} - {{$.V}}[{{$i}}] = {{$.V}}[{{$i}}] >> 1 | {{$.V}}[{{(add $i 1)}}] << 63 + {{ template "reduce" . }} + if neg { + // We have computed ( 2⁶³ r + X ) r⁻¹ = 2⁶³ + X r⁻¹ instead + var b uint64 + z[0], b = bits.Sub64(z[0], signBitSelector, 0) + + {{- range $i := .NbWordsIndexesNoZero}} + z[{{$i}}], b = bits.Sub64(z[{{$i}}], 0, b) {{- end}} - {{- end}} - {{$.V}}[{{$lastIndex}}] >>= 1 -{{ end }} + // Occurs iff x == 0 && xHi < 0, i.e. X = rX' for -2⁶³ ≤ X' < 0 + if b != 0 { + // z[{{.NbWordsLastIndex}}] = -1 + // negative: add q + const neg1 = 0xFFFFFFFFFFFFFFFF + + b = 0 + {{- range $i := .NbWordsIndexesNoLast}} + z[{{$i}}], b = bits.Add64(z[{{$i}}], q{{$.ElementName}}Word{{$i}}, b) + {{- end}} + z[{{.NbWordsLastIndex}}], _ = bits.Add64(neg1, q{{$.ElementName}}Word{{$.NbWordsLastIndex}}, b) + } + } +} + +// mulWSigned mul word signed (w/ montgomery reduction) +func (z *{{.ElementName}}) mulWSigned(x *{{.ElementName}}, y int64) { + m := y >> 63 + _mulWGeneric(z, x, uint64((y ^ m) - m)) + // multiply by abs(y) + if y < 0 { + z.Neg(z) + } +} +{{ end }} ` diff --git a/field/internal/templates/element/inverse_tests.go b/field/internal/templates/element/inverse_tests.go new file mode 100644 index 0000000000..3bd598f7ea --- /dev/null +++ b/field/internal/templates/element/inverse_tests.go @@ -0,0 +1,502 @@ +package element + +const InverseTests = ` + +{{if eq .NoCarry true}} + +func Test{{.ElementName}}InversionApproximation(t *testing.T) { + var x {{.ElementName}} + for i := 0; i < 1000; i++ { + x.SetRandom() + + // Normally small elements are unlikely. Here we give them a higher chance + xZeros := mrand.Int() % Limbs + for j := 1; j < xZeros; j++ { + x[Limbs - j] = 0 + } + + a := approximate(&x, x.BitLen()) + aRef := approximateRef(&x) + + if a != aRef { + t.Error("Approximation mismatch") + } + } +} + +func Test{{.ElementName}}InversionCorrectionFactorFormula(t *testing.T) { + const kLimbs = k * Limbs + const power = kLimbs*6 + invIterationsN*(kLimbs-k+1) + factorInt := big.NewInt(1) + factorInt.Lsh(factorInt, power) + factorInt.Mod(factorInt, Modulus()) + + var refFactorInt big.Int + inversionCorrectionFactor := {{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull }} + inversionCorrectionFactorWord{{$i}}, + {{- end}} + } + inversionCorrectionFactor.ToBigInt(&refFactorInt) + + if refFactorInt.Cmp(factorInt) != 0 { + t.Error("mismatch") + } +} + +func Test{{.ElementName}}LinearComb(t *testing.T) { + var x {{.ElementName}} + var y {{.ElementName}} + + for i := 0; i < 1000; i++ { + x.SetRandom() + y.SetRandom() + testLinearComb(t, &x, mrand.Int63(), &y, mrand.Int63()) + } +} + +// Probably unnecessary post-dev. In case the output of inv is wrong, this checks whether it's only off by a constant factor. +func Test{{.ElementName}}InversionCorrectionFactor(t *testing.T) { + + // (1/x)/inv(x) = (1/1)/inv(1) ⇔ inv(1) = x inv(x) + + var one {{.ElementName}} + var oneInv {{.ElementName}} + one.SetOne() + oneInv.Inverse(&one) + + for i := 0; i < 100; i++ { + var x {{.ElementName}} + var xInv {{.ElementName}} + x.SetRandom() + xInv.Inverse(&x) + + x.Mul(&x, &xInv) + if !x.Equal(&oneInv) { + t.Error("Correction factor is inconsistent") + } + } + + if !oneInv.Equal(&one) { + var i big.Int + oneInv.ToBigIntRegular(&i) // no montgomery + i.ModInverse(&i, Modulus()) + var fac {{.ElementName}} + fac.setBigInt(&i) // back to montgomery + + var facTimesFac {{.ElementName}} + facTimesFac.Mul(&fac, &{{.ElementName}}{ + {{- range $i := .NbWordsIndexesFull }} + inversionCorrectionFactorWord{{$i}}, + {{- end}} + }) + + t.Error("Correction factor is consistently off by", fac, "Should be", facTimesFac) + } +} + +func Test{{.ElementName}}BigNumNeg(t *testing.T) { + var a {{.ElementName}} + aHi := a.neg(&a, 0) + if !a.IsZero() || aHi != 0 { + t.Error("-0 != 0") + } +} + +func Test{{.ElementName}}BigNumWMul(t *testing.T) { + var x {{.ElementName}} + + for i := 0; i < 1000; i++ { + x.SetRandom() + w := mrand.Int63() + testBigNumWMul(t, &x, w) + } +} + +func Test{{.ElementName}}VeryBigIntConversion(t *testing.T) { + xHi := mrand.Uint64() + var x {{.ElementName}} + x.SetRandom() + var xInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + x.assertMatchVeryBigInt(t, xHi, &xInt) +} + +func Test{{.ElementName}}MontReducePos(t *testing.T) { + var x {{.ElementName}} + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() & ^signBitSelector) + } +} + +func Test{{.ElementName}}MontReduceNeg(t *testing.T) { + var x {{.ElementName}} + + for i := 0; i < 1000; i++ { + x.SetRandom() + testMontReduceSigned(t, &x, mrand.Uint64() | signBitSelector) + } +} + +func Test{{.ElementName}}MontNegMultipleOfR(t *testing.T) { + var zero {{.ElementName}} + + for i := 0; i < 1000; i++ { + testMontReduceSigned(t, &zero, mrand.Uint64() | signBitSelector) + } +} + +//TODO: Tests like this (update factor related) are common to all fields. Move them to somewhere non-autogen +func TestUpdateFactorSubtraction(t *testing.T) { + for i := 0; i < 1000; i++ { + + f0, g0 := randomizeUpdateFactors() + f1, g1 := randomizeUpdateFactors() + + for f0-f1 > 1<<31 || f0-f1 <= -1<<31 { + f1 /= 2 + } + + for g0-g1 > 1<<31 || g0-g1 <= -1<<31 { + g1 /= 2 + } + + c0 := updateFactorsCompose(f0, g0) + c1 := updateFactorsCompose(f1, g1) + + cRes := c0 - c1 + fRes, gRes := updateFactorsDecompose(cRes) + + if fRes != f0-f1 || gRes != g0-g1 { + t.Error(i) + } + } +} + +func TestUpdateFactorsDouble(t *testing.T) { + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f > 1<<30 || f < (-1<<31+1)/2 { + f /= 2 + if g <= 1<<29 && g >= (-1<<31+1)/4 { + g *= 2 //g was kept small on f's account. Now that we're halving f, we can double g + } + } + + if g > 1<<30 || g < (-1<<31+1)/2 { + g /= 2 + + if f <= 1<<29 && f >= (-1<<31+1)/4 { + f *= 2 //f was kept small on g's account. Now that we're halving g, we can double f + } + } + + c := updateFactorsCompose(f, g) + cD := c * 2 + fD, gD := updateFactorsDecompose(cD) + + if fD != 2*f || gD != 2*g { + t.Error(i) + } + } +} + +func TestUpdateFactorsNeg(t *testing.T) { + var fMistake bool + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + + if f == 0x80000000 || g == 0x80000000 { + // Update factors this large can only have been obtained after 31 iterations and will therefore never be negated + // We don't have capacity to store -2³¹ + // Repeat this iteration + i-- + continue + } + + c := updateFactorsCompose(f, g) + nc := -c + nf, ng := updateFactorsDecompose(nc) + fMistake = fMistake || nf != -f + if nf != -f || ng != -g { + t.Errorf("Mismatch iteration #%d:\n%d, %d ->\n %d -> %d ->\n %d, %d\n Inputs in hex: %X, %X", + i, f, g, c, nc, nf, ng, f, g) + } + } + if fMistake { + t.Error("Mistake with f detected") + } else { + t.Log("All good with f") + } +} + +func TestUpdateFactorsNeg0(t *testing.T) { + c := updateFactorsCompose(0, 0) + t.Logf("c(0,0) = %X", c) + cn := -c + + if c != cn { + t.Error("Negation of zero update factors should yield the same result.") + } +} + +func TestUpdateFactorDecomposition(t *testing.T) { + var negSeen bool + + for i := 0; i < 1000; i++ { + + f, g := randomizeUpdateFactors() + + if f <= -(1<<31) || f > 1<<31 { + t.Fatal("f out of range") + } + + negSeen = negSeen || f < 0 + + c := updateFactorsCompose(f, g) + + fBack, gBack := updateFactorsDecompose(c) + + if f != fBack || g != gBack { + t.Errorf("(%d, %d) -> %d -> (%d, %d)\n", f, g, c, fBack, gBack) + } + } + + if !negSeen { + t.Fatal("No negative f factors") + } +} + +func TestUpdateFactorInitialValues(t *testing.T) { + + f0, g0 := updateFactorsDecompose(updateFactorIdentityMatrixRow0) + f1, g1 := updateFactorsDecompose(updateFactorIdentityMatrixRow1) + + if f0 != 1 || g0 != 0 || f1 != 0 || g1 != 1 { + t.Error("Update factor initial value constants are incorrect") + } +} + +func TestUpdateFactorsRandomization(t *testing.T) { + var maxLen int + + //t.Log("|f| + |g| is not to exceed", 1 << 31) + for i := 0; i < 1000; i++ { + f, g := randomizeUpdateFactors() + lf, lg := abs64T32(f), abs64T32(g) + absSum := lf + lg + if absSum >= 1<<31 { + + if absSum == 1<<31 { + maxLen++ + } else { + t.Error(i, "Sum of absolute values too large, f =", f, ",g =", g, ",|f| + |g| =", absSum) + } + } + } + + if maxLen == 0 { + t.Error("max len not observed") + } else { + t.Log(maxLen, "maxLens observed") + } +} + +func randomizeUpdateFactor(absLimit uint32) int64 { + const maxSizeLikelihood = 10 + maxSize := mrand.Intn(maxSizeLikelihood) + + absLimit64 := int64(absLimit) + var f int64 + switch maxSize { + case 0: + f = absLimit64 + case 1: + f = -absLimit64 + default: + f = int64(mrand.Uint64()%(2*uint64(absLimit64)+1)) - absLimit64 + } + + if f > 1<<31 { + return 1 << 31 + } else if f < -1<<31+1 { + return -1<<31 + 1 + } + + return f +} + +func abs64T32(f int64) uint32 { + if f >= 1<<32 || f < -1<<32 { + panic("f out of range") + } + + if f < 0 { + return uint32(-f) + } + return uint32(f) +} + +func randomizeUpdateFactors() (int64, int64) { + var f [2]int64 + b := mrand.Int() % 2 + + f[b] = randomizeUpdateFactor(1 << 31) + + //As per the paper, |f| + |g| \le 2³¹. + f[1-b] = randomizeUpdateFactor(1<<31 - abs64T32(f[b])) + + //Patching another edge case + if f[0]+f[1] == -1<<31 { + b = mrand.Int() % 2 + f[b]++ + } + + return f[0], f[1] +} + +func testLinearComb(t *testing.T, x *{{.ElementName}}, xC int64, y *{{.ElementName}}, yC int64) { + + var p1 big.Int + x.ToBigInt(&p1) + p1.Mul(&p1, big.NewInt(xC)) + + var p2 big.Int + y.ToBigInt(&p2) + p2.Mul(&p2, big.NewInt(yC)) + + p1.Add(&p1, &p2) + p1.Mod(&p1, Modulus()) + montReduce(&p1, &p1) + + var z {{.ElementName}} + z.linearCombSosSigned(x, xC, y, yC) + z.assertMatchVeryBigInt(t, 0, &p1) +} + + +func testBigNumWMul(t *testing.T, a *{{.ElementName}}, c int64) { + var aHi uint64 + var aTimes {{.ElementName}} + aHi = aTimes.mulWRegular(a, c) + + assertMulProduct(t, a, c, &aTimes, aHi) +} + +func testMontReduceSigned(t *testing.T, x *{{.ElementName}}, xHi uint64) { + var res {{.ElementName}} + var xInt big.Int + var resInt big.Int + x.toVeryBigIntSigned(&xInt, xHi) + res.montReduceSigned(x, xHi) + montReduce(&resInt, &xInt) + res.assertMatchVeryBigInt(t, 0, &resInt) +} + +func updateFactorsCompose(f int64, g int64) int64 { + return f + g<<32 +} + +var rInv big.Int +func montReduce(res *big.Int, x *big.Int) { + if rInv.BitLen() == 0 { // initialization + rInv.SetUint64(1) + rInv.Lsh(&rInv, Limbs * 64) + rInv.ModInverse(&rInv, Modulus()) + } + res.Mul(x, &rInv) + res.Mod(res, Modulus()) +} + +func (z *{{.ElementName}}) toVeryBigIntUnsigned(i *big.Int, xHi uint64) { + z.ToBigInt(i) + var upperWord big.Int + upperWord.SetUint64(xHi) + upperWord.Lsh(&upperWord, Limbs*64) + i.Add(&upperWord, i) +} + +func (z *{{.ElementName}}) toVeryBigIntSigned(i *big.Int, xHi uint64) { + z.toVeryBigIntUnsigned(i, xHi) + if signBitSelector&xHi != 0 { + twosCompModulus := big.NewInt(1) + twosCompModulus.Lsh(twosCompModulus, (Limbs+1)*64) + i.Sub(i, twosCompModulus) + } +} + +func assertMulProduct(t *testing.T, x *{{.ElementName}}, c int64, result *{{.ElementName}}, resultHi uint64) big.Int { + var xInt big.Int + x.ToBigInt(&xInt) + + xInt.Mul(&xInt, big.NewInt(c)) + + result.assertMatchVeryBigInt(t, resultHi, &xInt) + return xInt +} + +func assertMatch(t *testing.T, w []big.Word, a uint64, index int) { + + var wI big.Word + + if index < len(w) { + wI = w[index] + } + + const filter uint64 = 0xFFFFFFFFFFFFFFFF >> (64 - bits.UintSize) + + a = a >> ((index * bits.UintSize) % 64) + a &= filter + + if uint64(wI) != a { + t.Error("Bignum mismatch: disagreement on word", index) + } +} + +func (z *{{.ElementName}}) assertMatchVeryBigInt(t *testing.T, aHi uint64, aInt *big.Int) { + + var modulus big.Int + var aIntMod big.Int + modulus.SetInt64(1) + modulus.Lsh(&modulus, (Limbs+1)*64) + aIntMod.Mod(aInt, &modulus) + + words := aIntMod.Bits() + + const steps = 64 / bits.UintSize + for i := 0; i < Limbs*steps; i++ { + assertMatch(t, words, z[i/steps], i) + } + + for i := 0; i < steps; i++ { + assertMatch(t, words, aHi, Limbs*steps+i) + } +} + +func approximateRef(x *{{.ElementName}}) uint64 { + + var asInt big.Int + x.ToBigInt(&asInt) + n := x.BitLen() + + if n <= 64 { + return asInt.Uint64() + } + + modulus := big.NewInt(1 << 31) + var lo big.Int + lo.Mod(&asInt, modulus) + + modulus.Lsh(modulus, uint(n-64)) + var hi big.Int + hi.Div(&asInt, modulus) + hi.Lsh(&hi, 31) + + hi.Add(&hi, &lo) + return hi.Uint64() +} +{{- end}} +` diff --git a/field/internal/templates/element/mul_nocarry.go b/field/internal/templates/element/mul_nocarry.go index 53dea7d308..a492c6fc2a 100644 --- a/field/internal/templates/element/mul_nocarry.go +++ b/field/internal/templates/element/mul_nocarry.go @@ -49,4 +49,51 @@ var c [3]uint64 } {{- end}} {{ end }} + + + + +{{ define "mul_nocarry_v2" }} +var t [{{.all.NbWords}}]uint64 + +{{- range $j := .all.NbWordsIndexesFull}} +{ + // round {{$j}} + + {{- if eq $j 0}} + c1, c0 := bits.Mul64(y, {{$.V2}}[0]) + m := c0 * {{index $.all.QInverse 0}} + c2 := madd0(m, {{index $.all.Q 0}}, c0) + {{- range $i := $.all.NbWordsIndexesNoZero}} + c1, c0 = madd1(y, {{$.V2}}[{{$i}}], c1) + {{- if eq $i $.all.NbWordsLastIndex}} + t[{{sub $.all.NbWords 1}}], t[{{sub $i 1}}] = madd3(m, {{index $.all.Q $i}}, c0, c2, c1) + {{- else}} + c2, t[{{sub $i 1}}] = madd2(m, {{index $.all.Q $i}}, c2, c0) + {{- end}} + {{- end}} + {{- else if eq $j $.all.NbWordsLastIndex}} + m := t[0] * {{index $.all.QInverse 0}} + c2 := madd0(m, {{index $.all.Q 0}}, t[0]) + {{- range $i := $.all.NbWordsIndexesNoZero}} + {{- if eq $i $.all.NbWordsLastIndex}} + z[{{sub $.all.NbWords 1}}], z[{{sub $i 1}}] = madd2(m, {{index $.all.Q $i}}, t[{{$i}}], c2) + {{- else}} + c2, z[{{sub $i 1}}] = madd2(m, {{index $.all.Q $i}}, c2, t[{{$i}}]) + {{- end}} + {{- end}} + {{- else}} + m := t[0] * {{index $.all.QInverse 0}} + c2 := madd0(m, {{index $.all.Q 0}}, t[0]) + {{- range $i := $.all.NbWordsIndexesNoZero}} + {{- if eq $i $.all.NbWordsLastIndex}} + t[{{sub $.all.NbWords 1}}], t[{{sub $i 1}}] = madd2(m, {{index $.all.Q $i}}, t[{{$i}}], c2) + {{- else}} + c2, t[{{sub $i 1}}] = madd2(m, {{index $.all.Q $i}}, c2, t[{{$i}}]) + {{- end}} + {{- end}} + {{- end }} +} +{{- end}} +{{ end }} ` diff --git a/field/internal/templates/element/reduce.go b/field/internal/templates/element/reduce.go index f1c9a75251..5d5d1bd0a9 100644 --- a/field/internal/templates/element/reduce.go +++ b/field/internal/templates/element/reduce.go @@ -2,7 +2,7 @@ package element const Reduce = ` {{ define "reduce" }} -// if z > q --> z -= q +// if z > q → z -= q // note: this is NOT constant time if !({{- range $i := reverse .NbWordsIndexesNoZero}} z[{{$i}}] < {{index $.Q $i}} || ( z[{{$i}}] == {{index $.Q $i}} && ( {{- end}}z[0] < {{index $.Q 0}} {{- range $i := .NbWordsIndexesNoZero}} )) {{- end}} ){ diff --git a/field/internal/templates/element/tests.go b/field/internal/templates/element/tests.go index 8c6a603eca..8b83e3a75d 100644 --- a/field/internal/templates/element/tests.go +++ b/field/internal/templates/element/tests.go @@ -7,6 +7,7 @@ import ( "encoding/json" "math/big" "math/bits" + {{if .NoCarry}} mrand "math/rand" {{end}} "testing" {{if .UseAddChain}} "fmt" {{ end }} @@ -268,7 +269,7 @@ var staticTestValues []{{.ElementName}} func init() { staticTestValues = append(staticTestValues, {{.ElementName}}{}) // zero staticTestValues = append(staticTestValues, One()) // one - staticTestValues = append(staticTestValues, rSquare) // r^2 + staticTestValues = append(staticTestValues, rSquare) // r² var e, one {{.ElementName}} one.SetOne() e.Sub(&q{{.ElementName}}, &one)