Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fiat-Shamir transcript using a short hash #900

Merged
merged 12 commits into from
Nov 8, 2023
3 changes: 3 additions & 0 deletions std/algebra/emulated/sw_emulated/point.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ func (c *Curve[B, S]) MarshalG1(p AffinePoint[B]) []frontend.Variable {
res := make([]frontend.Variable, 2*nbBits)
copy(res, bx)
copy(res[len(bx):], by)
xZ := c.baseApi.IsZero(x)
yZ := c.baseApi.IsZero(y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
53 changes: 33 additions & 20 deletions std/algebra/emulated/sw_emulated/point_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,40 @@ func (c *MarshalG1Test[T, S]) Define(api frontend.API) error {

func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)
_, _, g, _ := bw6761.Generators()
gBytes := g.Marshal()
nbBytes := 2 * fp_bw6761.Bytes
nbBits := nbBytes * 8
circuit := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
R: make([]frontend.Variable, nbBits),
}
witness := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
G: AffinePoint[emulated.BW6761Fp]{
X: emulated.ValueOf[emulated.BW6761Fp](g.X),
Y: emulated.ValueOf[emulated.BW6761Fp](g.Y),
},
R: make([]frontend.Variable, nbBits),
}
for i := 0; i < nbBytes; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testFn := func(r fr_bw6761.Element) {
var P bw6761.G1Affine
P.ScalarMultiplicationBase(r.BigInt(new(big.Int)))
gBytes := P.Marshal()
nbBytes := 2 * fp_bw6761.Bytes
nbBits := nbBytes * 8
circuit := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
R: make([]frontend.Variable, nbBits),
}
}
err := test.IsSolved(circuit, witness, testCurve.ScalarField())
assert.NoError(err)
witness := &MarshalG1Test[emulated.BW6761Fp, emulated.BW6761Fr]{
G: AffinePoint[emulated.BW6761Fp]{
X: emulated.ValueOf[emulated.BW6761Fp](P.X),
Y: emulated.ValueOf[emulated.BW6761Fp](P.Y),
},
R: make([]frontend.Variable, nbBits),
}
for i := 0; i < nbBytes; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
err := test.IsSolved(circuit, witness, testCurve.ScalarField())
assert.NoError(err)
}
assert.Run(func(assert *test.Assert) {
var r fr_bw6761.Element
r.SetRandom()
testFn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr_bw6761.Element
r.SetZero()
testFn(r)
})
}

type NegTest[T, S emulated.FieldParams] struct {
Expand Down
40 changes: 25 additions & 15 deletions std/algebra/native/sw_bls12377/g1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,33 @@ func (c *MarshalG1Test) Define(api frontend.API) error {
func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)

// sample a random point
var r fr.Element
r.SetRandom()
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls12377.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 96; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testfn := func(r fr.Element) {
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls12377.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 96; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761))
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_761))
assert.Run(func(assert *test.Assert) {
// sample a random point
var r fr.Element
r.SetRandom()
testfn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr.Element
r.SetZero()
testfn(r)
})
}

// -------------------------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions std/algebra/native/sw_bls12377/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
}
xZ := c.api.IsZero(P.X)
yZ := c.api.IsZero(P.Y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
40 changes: 25 additions & 15 deletions std/algebra/native/sw_bls24315/g1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,33 @@ func (c *MarshalG1Test) Define(api frontend.API) error {
func TestMarshalG1(t *testing.T) {
assert := test.NewAssert(t)

// sample a random point
var r fr.Element
r.SetRandom()
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls24315.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 80; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
testfn := func(r fr.Element) {
var br big.Int
r.BigInt(&br)
_, _, g, _ := bls24315.Generators()
g.ScalarMultiplication(&g, &br)
gBytes := g.Marshal()
var witness MarshalG1Test
witness.P.Assign(&g)
for i := 0; i < 80; i++ {
for j := 0; j < 8; j++ {
witness.R[i*8+j] = (gBytes[i] >> (7 - j)) & 1
}
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_633))
}
var circuit MarshalG1Test
assert.CheckCircuit(&circuit, test.WithValidAssignment(&witness), test.WithCurves(ecc.BW6_633))
assert.Run(func(assert *test.Assert) {
// sample a random point
var r fr.Element
r.SetRandom()
testfn(r)
})
assert.Run(func(assert *test.Assert) {
var r fr.Element
r.SetZero()
testfn(r)
})
}

// -------------------------------------------------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions std/algebra/native/sw_bls24315/pairing2.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ func (c *Curve) MarshalG1(P G1Affine) []frontend.Variable {
res[i] = x[nbBits-1-i]
res[i+nbBits] = y[nbBits-1-i]
}
xZ := c.api.IsZero(P.X)
yZ := c.api.IsZero(P.Y)
res[1] = c.api.Mul(xZ, yZ)
return res
}

Expand Down
2 changes: 1 addition & 1 deletion std/commitments/fri/fri.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (s RadixTwoFri) verifyProofOfProximitySingleRound(api frontend.API, salt fr
xis[i] = paddNaming(fmt.Sprintf("x%d", i), frSize)
}
xis[s.nbSteps] = paddNaming("s0", frSize)
fs := fiatshamir.NewTranscript(api, s.h, xis...)
fs := fiatshamir.NewTranscript(api, s.h, xis, fiatshamir.WithDomainSeparation())
xi := make([]frontend.Variable, s.nbSteps)

// the salt is binded to the first challenge, to ensure the challenges
Expand Down
72 changes: 62 additions & 10 deletions std/fiat-shamir/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ package fiatshamir

import (
"errors"
"slices"

"github.com/consensys/gnark/constant"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/hash"
"github.com/consensys/gnark/std/math/bits"
)

// errChallengeNotFound is returned when a wrong challenge name is provided.
Expand All @@ -41,6 +43,8 @@ type Transcript struct {

// gnark API
api frontend.API

config transcriptConfig
}

type challenge struct {
Expand All @@ -53,19 +57,24 @@ type challenge struct {
// NewTranscript returns a new transcript.
// h is the hash function that is used to compute the challenges.
// challenges are the name of the challenges. The order is important.
func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID ...string) Transcript {
func NewTranscript(api frontend.API, h hash.FieldHasher, challengesID []string, opts ...TranscriptOption) *Transcript {
cfg := transcriptConfig{}
for _, opt := range opts {
opt(&cfg)
}
n := len(challengesID)
t := Transcript{
challenges: make(map[string]challenge, n),
api: api,
h: h,
config: cfg,
}

for i := 0; i < n; i++ {
t.challenges[challengesID[i]] = challenge{position: i}
}

return t
return &t
}

// Bind binds the challenge to value. A challenge can be binded to an
Expand All @@ -92,10 +101,10 @@ func (t *Transcript) Bind(challengeID string, values []frontend.Variable) error

// ComputeChallenge computes the challenge corresponding to the given name.
// The resulting variable is:
// * H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one
// * H(name ∥ binded_values... ) if it's is the first challenge
// - H(name ∥ previous_challenge ∥ binded_values...) if the challenge is not the first one
// - H(name ∥ binded_values... ) if it's is the first challenge
func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, error) {

var err error
challenge, ok := t.challenges[challengeID]

if !ok {
Expand All @@ -110,19 +119,34 @@ func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, er
t.h.Reset()

// write the challenge name, the purpose is to have a domain separator
cChallenge := []byte(challengeID) // if we send a string, it is assumed to be a base10 number
if challengeName, err := constant.HashedBytes(t.api, cChallenge); err == nil {
t.h.Write(challengeName)
challengeInput := []byte(challengeID)
var challengeHashInput frontend.Variable = challengeInput
if t.config.withDomainSeparation {
ivokub marked this conversation as resolved.
Show resolved Hide resolved
challengeHashInput, err = constant.HashedBytes(t.api, []byte(challengeID))
if err != nil {
return nil, err
}
}
if t.config.tryBitmode > 0 {
challengeBits := bits.ToBinary(t.api, challengeInput, bits.WithNbDigits(8*len(challengeInput)))
slices.Reverse(challengeBits)
t.h.Write(challengeBits...)
} else {
return nil, err
t.h.Write(challengeHashInput)
}

// write the previous challenge if it's not the first challenge
if challenge.position != 0 {
if t.previous == nil || (t.previous.position != challenge.position-1) {
return nil, errPreviousChallengeNotComputed
}
t.h.Write(t.previous.value)
if t.config.tryBitmode > 0 {
prevBits := bits.ToBinary(t.api, t.previous.value, bits.WithNbDigits(t.config.tryBitmode))
slices.Reverse(prevBits)
t.h.Write(prevBits...)
} else {
t.h.Write(t.previous.value)
}
}

// write the binded values in the order they were added
Expand All @@ -140,3 +164,31 @@ func (t *Transcript) ComputeChallenge(challengeID string) (frontend.Variable, er
return challenge.value, nil

}

type transcriptConfig struct {
tryBitmode int
withDomainSeparation bool
}

// TranscriptOption allows modifying the [Transcript] operation.
type TranscriptOption func(tc *transcriptConfig)

// WithTryBitmode changes the [Transcript] to work on bits instead of field
// elements when writing input to the hasher. Requires that the hasher is also
// set to work in bitmode. This mode of operation is useful in cases where we
// work in mismatching fields and want to avoid overflows.
func WithTryBitmode(nbBits int) TranscriptOption {
return func(tc *transcriptConfig) {
tc.tryBitmode = nbBits
}
}

// WithDomainSeparation adds domain separation string `string:` as defined in
// RCF 9380. This mode of operation is beneficial when seeking for compatibility
// with native Transcript when initialized using gnark-crypto's MiMC
// implementation.
func WithDomainSeparation() TranscriptOption {
return func(tc *transcriptConfig) {
tc.withDomainSeparation = true
}
}
9 changes: 6 additions & 3 deletions std/fiat-shamir/transcript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package fiatshamir

import (
"crypto/rand"
"math/big"
"testing"

Expand Down Expand Up @@ -45,7 +46,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error {
}

// New transcript with 3 challenges to be derived
tsSnark := NewTranscript(api, &hSnark, "alpha", "beta", "gamma")
tsSnark := NewTranscript(api, &hSnark, []string{"alpha", "beta", "gamma"}, WithDomainSeparation())

// Bind challenges
if err := tsSnark.Bind("alpha", circuit.Bindings[0][:]); err != nil {
Expand Down Expand Up @@ -83,6 +84,7 @@ func (circuit *FiatShamirCircuit) Define(api frontend.API) error {
}

func TestFiatShamir(t *testing.T) {
var err error
assert := test.NewAssert(t)

testData := map[ecc.ID]hash.Hash{
Expand All @@ -101,10 +103,11 @@ func TestFiatShamir(t *testing.T) {
// instantiate the hash and the transcript in plain go
ts := fiatshamir.NewTranscript(h.New(), "alpha", "beta", "gamma")

var bindings [3][4]big.Int
var bindings [3][4]*big.Int
for i := 0; i < 3; i++ {
for j := 0; j < 4; j++ {
bindings[i][j].SetUint64(uint64(i * j))
bindings[i][j], err = rand.Int(rand.Reader, curveID.ScalarField())
assert.NoError(err)
}
}
frSize := utils.ByteLen(curveID.ScalarField())
Expand Down
7 changes: 3 additions & 4 deletions std/gkr/gkr.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package gkr

import (
"fmt"
"strconv"

"github.com/consensys/gnark/frontend"
fiatshamir "github.com/consensys/gnark/std/fiat-shamir"
"github.com/consensys/gnark/std/polynomial"
"github.com/consensys/gnark/std/sumcheck"
"strconv"
)

// @tabaie TODO: Contains many things copy-pasted from gnark-crypto. Generify somehow?
Expand Down Expand Up @@ -198,9 +199,7 @@ func setup(api frontend.API, c Circuit, assignment WireAssignment, transcriptSet

if transcriptSettings.Transcript == nil {
challengeNames := ChallengeNames(o.sorted, o.nbVars, transcriptSettings.Prefix)
transcript := fiatshamir.NewTranscript(
api, transcriptSettings.Hash, challengeNames...)
o.transcript = &transcript
o.transcript = fiatshamir.NewTranscript(api, transcriptSettings.Hash, challengeNames)
if err = o.transcript.Bind(challengeNames[0], transcriptSettings.BaseChallenges); err != nil {
return o, err
}
Expand Down
Loading