Skip to content

Commit

Permalink
refactor: consolidate bw6-761 tower + fix GT exp tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yelhousni committed Oct 5, 2023
1 parent 996c82f commit 70715c0
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 249 deletions.
46 changes: 21 additions & 25 deletions std/algebra/emulated/fields_bw6761/e3.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"math/big"

bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/math/emulated"
)

Expand All @@ -15,12 +16,18 @@ type E3 struct {
}

type Ext3 struct {
fp *curveF
api frontend.API
fp *curveF
}

func NewExt3(baseEl *curveF) *Ext3 {
func NewExt3(api frontend.API) *Ext3 {
fp, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
return &Ext3{
fp: baseEl,
api: api,
fp: fp,
}
}

Expand All @@ -32,7 +39,6 @@ func (e Ext3) Reduce(x *E3) *E3 {
return &z
}

// SetZero sets an *E3 elmt to zero
func (e Ext3) Zero() *E3 {
zero := e.fp.Zero()
return &E3{
Expand All @@ -42,7 +48,6 @@ func (e Ext3) Zero() *E3 {
}
}

// One sets z to 1 in Montgomery form and returns z
func (e Ext3) One() *E3 {
one := e.fp.One()
zero := e.fp.Zero()
Expand All @@ -53,7 +58,6 @@ func (e Ext3) One() *E3 {
}
}

// Neg negates the *E3 number
func (e Ext3) Neg(x *E3) *E3 {
a0 := e.fp.Neg(&x.A0)
a1 := e.fp.Neg(&x.A1)
Expand All @@ -65,7 +69,6 @@ func (e Ext3) Neg(x *E3) *E3 {
}
}

// Add adds two elements of *E3
func (e Ext3) Add(x, y *E3) *E3 {
a0 := e.fp.Add(&x.A0, &y.A0)
a1 := e.fp.Add(&x.A1, &y.A1)
Expand All @@ -77,7 +80,6 @@ func (e Ext3) Add(x, y *E3) *E3 {
}
}

// Sub two elements of *E3
func (e Ext3) Sub(x, y *E3) *E3 {
a0 := e.fp.Sub(&x.A0, &y.A0)
a1 := e.fp.Sub(&x.A1, &y.A1)
Expand All @@ -89,7 +91,6 @@ func (e Ext3) Sub(x, y *E3) *E3 {
}
}

// Double doubles an element in *E3
func (e Ext3) Double(x *E3) *E3 {
two := big.NewInt(2)
a0 := e.fp.MulConst(&x.A0, two)
Expand All @@ -102,14 +103,13 @@ func (e Ext3) Double(x *E3) *E3 {
}
}

func MulByNonResidue(fp *curveF, x *baseEl) *baseEl {
func mulFpByNonResidue(fp *curveF, x *baseEl) *baseEl {

z := fp.Neg(x)
z = fp.MulConst(z, big.NewInt(4))
return z
}

// Conjugate conjugates an element in *E3
func (e Ext3) Conjugate(x *E3) *E3 {
a1 := e.fp.Neg(&x.A1)
return &E3{
Expand All @@ -119,7 +119,6 @@ func (e Ext3) Conjugate(x *E3) *E3 {
}
}

// MulByElement multiplies an element in *E3 by an element in fp
func (e Ext3) MulByElement(x *E3, y *baseEl) *E3 {
a0 := e.fp.Mul(&x.A0, y)
a1 := e.fp.Mul(&x.A1, y)
Expand Down Expand Up @@ -152,7 +151,7 @@ func (e Ext3) MulBy01(z *E3, c0, c1 *baseEl) *E3 {
tmp := e.fp.Add(&z.A1, &z.A2)
t0 := e.fp.Mul(c1, tmp)
t0 = e.fp.Sub(t0, b)
t0 = MulByNonResidue(e.fp, t0)
t0 = mulFpByNonResidue(e.fp, t0)
t0 = e.fp.Add(t0, a)

tmp = e.fp.Add(&z.A0, &z.A2)
Expand Down Expand Up @@ -181,7 +180,7 @@ func (e Ext3) MulBy1(z *E3, c1 baseEl) *E3 {
tmp := e.fp.Add(&z.A1, &z.A2)
t0 := e.fp.Mul(&c1, tmp)
t0 = e.fp.Sub(t0, b)
t0 = MulByNonResidue(e.fp, t0)
t0 = mulFpByNonResidue(e.fp, t0)

tmp = e.fp.Add(&z.A0, &z.A1)
t1 := e.fp.Mul(&c1, tmp)
Expand All @@ -203,11 +202,11 @@ func (e Ext3) MulBy12(x *E3, b1, b2 *baseEl) *E3 {
c0 = e.fp.Mul(c0, tmp)
c0 = e.fp.Sub(c0, t1)
c0 = e.fp.Sub(c0, t2)
c0 = MulByNonResidue(e.fp, c0)
c0 = mulFpByNonResidue(e.fp, c0)
c1 := e.fp.Add(&x.A0, &x.A1)
c1 = e.fp.Mul(c1, b1)
c1 = e.fp.Sub(c1, t1)
tmp = MulByNonResidue(e.fp, t2)
tmp = mulFpByNonResidue(e.fp, t2)
c1 = e.fp.Add(c1, tmp)
tmp = e.fp.Add(&x.A0, &x.A2)
c2 := e.fp.Mul(b2, tmp)
Expand Down Expand Up @@ -240,7 +239,7 @@ func (e Ext3) Mul01By01(c0, c1, d0, d1 *baseEl) *E3 {
b := e.fp.Mul(d1, c1)
t0 := e.fp.Mul(c1, d1)
t0 = e.fp.Sub(t0, b)
t0 = MulByNonResidue(e.fp, t0)
t0 = mulFpByNonResidue(e.fp, t0)
t0 = e.fp.Add(t0, a)
t2 := e.fp.Mul(c0, d0)
t2 = e.fp.Sub(t2, a)
Expand All @@ -257,7 +256,6 @@ func (e Ext3) Mul01By01(c0, c1, d0, d1 *baseEl) *E3 {
}
}

// Mul sets z to the *E3-product of x,y, returns z
func (e Ext3) Mul(x, y *E3) *E3 {
// Algorithm 13 from https://eprint.iacr.org/2010/354.pdf
t0 := e.fp.Mul(&x.A0, &y.A0)
Expand All @@ -269,7 +267,7 @@ func (e Ext3) Mul(x, y *E3) *E3 {
c0 = e.fp.Mul(c0, tmp)
c0 = e.fp.Sub(c0, t1)
c0 = e.fp.Sub(c0, t2)
c0 = MulByNonResidue(e.fp, c0)
c0 = mulFpByNonResidue(e.fp, c0)

tmp = e.fp.Add(&x.A0, &x.A2)
c2 := e.fp.Add(&y.A0, &y.A2)
Expand All @@ -282,7 +280,7 @@ func (e Ext3) Mul(x, y *E3) *E3 {
c1 = e.fp.Mul(c1, tmp)
c1 = e.fp.Sub(c1, t0)
c1 = e.fp.Sub(c1, t1)
t2 = MulByNonResidue(e.fp, t2)
t2 = mulFpByNonResidue(e.fp, t2)

a0 := e.fp.Add(c0, t0)
a1 := e.fp.Add(c1, t2)
Expand All @@ -295,15 +293,14 @@ func (e Ext3) Mul(x, y *E3) *E3 {
}
}

// Square sets z to the *E3-product of x,x, returns z
func (e Ext3) Square(x *E3) *E3 {

// Algorithm 16 from https://eprint.iacr.org/2010/354.pdf

c6 := e.fp.MulConst(&x.A1, big.NewInt(2))
c4 := e.fp.Mul(&x.A0, c6) // x.A0 * xA1 * 2
c5 := e.fp.Mul(&x.A2, &x.A2)
c1 := MulByNonResidue(e.fp, c5)
c1 := mulFpByNonResidue(e.fp, c5)
c1 = e.fp.Add(c1, c4)
c2 := e.fp.Sub(c4, c5)

Expand All @@ -312,7 +309,7 @@ func (e Ext3) Square(x *E3) *E3 {
c4 = e.fp.Add(c4, &x.A2)
c5 = e.fp.Mul(c6, &x.A2) // x.A1 * xA2 * 2
c4 = e.fp.Mul(c4, c4)
c0 := MulByNonResidue(e.fp, c5)
c0 := mulFpByNonResidue(e.fp, c5)
c4 = e.fp.Add(c4, c5)
c4 = e.fp.Sub(c4, c3)

Expand Down Expand Up @@ -377,11 +374,10 @@ func (e Ext3) MulByNonResidue(x *E3) *E3 {
A1: x.A0,
A2: x.A1,
}
z.A0 = *MulByNonResidue(e.fp, &z.A0)
z.A0 = *mulFpByNonResidue(e.fp, &z.A0)
return z
}

// AssertIsEqual constraint self to be equal to other into the given constraint system
func (e Ext3) AssertIsEqual(a, b *E3) {
e.fp.AssertIsEqual(&a.A0, &b.A0)
e.fp.AssertIsEqual(&a.A1, &b.A1)
Expand Down
72 changes: 12 additions & 60 deletions std/algebra/emulated/fields_bw6761/e3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ type e3Add struct {
}

func (circuit *e3Add) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Add(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)
return nil
Expand Down Expand Up @@ -49,11 +45,7 @@ type e3Sub struct {
}

func (circuit *e3Sub) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Sub(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)
return nil
Expand Down Expand Up @@ -82,11 +74,7 @@ type e3Neg struct {
}

func (circuit *e3Neg) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Neg(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand All @@ -113,11 +101,7 @@ type e3Double struct {
}

func (circuit *e3Double) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Double(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand All @@ -144,11 +128,7 @@ type e3Mul struct {
}

func (circuit *e3Mul) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Mul(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)
return nil
Expand Down Expand Up @@ -177,11 +157,7 @@ type e3MulByNonResidue struct {
}

func (circuit *e3MulByNonResidue) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.MulByNonResidue(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand Down Expand Up @@ -211,11 +187,7 @@ type e3MulByElement struct {
}

func (circuit *e3MulByElement) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.MulByElement(&circuit.A, &circuit.Y)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand Down Expand Up @@ -248,11 +220,7 @@ type e3MulBy01 struct {
}

func (circuit *e3MulBy01) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.MulBy01(&circuit.A, &circuit.C0, &circuit.C1)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand Down Expand Up @@ -285,11 +253,7 @@ type e3Square struct {
}

func (circuit *e3Square) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Square(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand All @@ -316,11 +280,7 @@ type e3Inverse struct {
}

func (circuit *e3Inverse) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Inverse(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand Down Expand Up @@ -348,11 +308,7 @@ type e3Div struct {
}

func (circuit *e3Div) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.DivUnchecked(&circuit.A, &circuit.B)
e.AssertIsEqual(expected, &circuit.C)
return nil
Expand Down Expand Up @@ -384,11 +340,7 @@ type e3Conjugate struct {
}

func (circuit *e3Conjugate) Define(api frontend.API) error {
nfield, err := emulated.NewField[emulated.BW6761Fp](api)
if err != nil {
panic(err)
}
e := NewExt3(nfield)
e := NewExt3(api)
expected := e.Conjugate(&circuit.A)
e.AssertIsEqual(expected, &circuit.B)
return nil
Expand Down
Loading

0 comments on commit 70715c0

Please sign in to comment.