diff --git a/constraint/bls12-377/gkr.go b/constraint/bls12-377/gkr.go index ef851cfb8f..adfe94b61a 100644 --- a/constraint/bls12-377/gkr.go +++ b/constraint/bls12-377/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bls12-381/gkr.go b/constraint/bls12-381/gkr.go index f70fb34a48..811d8d8f7f 100644 --- a/constraint/bls12-381/gkr.go +++ b/constraint/bls12-381/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bls24-315/gkr.go b/constraint/bls24-315/gkr.go index c7c2c9ed2a..21fee9bcbd 100644 --- a/constraint/bls24-315/gkr.go +++ b/constraint/bls24-315/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bls24-317/gkr.go b/constraint/bls24-317/gkr.go index f38ac45c92..516f92fcde 100644 --- a/constraint/bls24-317/gkr.go +++ b/constraint/bls24-317/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bn254/gkr.go b/constraint/bn254/gkr.go index 8b3ea26755..9ce0a8a161 100644 --- a/constraint/bn254/gkr.go +++ b/constraint/bn254/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bw6-633/gkr.go b/constraint/bw6-633/gkr.go index 8018e76d3d..862843f646 100644 --- a/constraint/bw6-633/gkr.go +++ b/constraint/bw6-633/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/constraint/bw6-761/gkr.go b/constraint/bw6-761/gkr.go index d69ee82d31..a2b159cdff 100644 --- a/constraint/bw6-761/gkr.go +++ b/constraint/bw6-761/gkr.go @@ -28,6 +28,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -164,9 +165,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -193,4 +197,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} diff --git a/internal/generator/backend/template/representations/gkr.go.tmpl b/internal/generator/backend/template/representations/gkr.go.tmpl index 110276de60..1049da66bb 100644 --- a/internal/generator/backend/template/representations/gkr.go.tmpl +++ b/internal/generator/backend/template/representations/gkr.go.tmpl @@ -10,6 +10,7 @@ import ( "github.com/consensys/gnark/std/utils/algo_utils" "hash" "math/big" + "sync" ) type GkrSolvingData struct { @@ -146,9 +147,12 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { return b[:] }) - hsh := HashBuilderRegistry[hashName]() + hsh, err := GetHashBuilder(hashName) + if err != nil { + return err + } - proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh, insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) + proof, err := gkr.Prove(data.circuit, data.assignments, fiatshamir.WithHash(hsh(), insBytes...), gkr.WithPool(&data.memoryPool), gkr.WithWorkers(data.workers)) if err != nil { return err } @@ -175,4 +179,23 @@ func GkrProveHint(hashName string, data *GkrSolvingData) hint.Hint { } // TODO: Move to gnark-crypto -var HashBuilderRegistry = make(map[string]func() hash.Hash) \ No newline at end of file +var ( + hashBuilderRegistry = make(map[string]func() hash.Hash) + hasBuilderLock sync.RWMutex +) + +func RegisterHashBuilder(name string, builder func() hash.Hash) { + hasBuilderLock.Lock() + defer hasBuilderLock.Unlock() + hashBuilderRegistry[name] = builder +} + +func GetHashBuilder(name string) (func() hash.Hash, error) { + hasBuilderLock.RLock() + defer hasBuilderLock.RUnlock() + builder, ok := hashBuilderRegistry[name] + if !ok { + return nil, fmt.Errorf("hash function not found") + } + return builder, nil +} \ No newline at end of file diff --git a/std/gkr/api_test.go b/std/gkr/api_test.go index d39b103962..a0d1d4edbb 100644 --- a/std/gkr/api_test.go +++ b/std/gkr/api_test.go @@ -2,17 +2,18 @@ package gkr import ( "fmt" - "github.com/consensys/gnark-crypto/kzg" - "github.com/consensys/gnark/backend/plonk" - bn254r1cs "github.com/consensys/gnark/constraint/bn254" - "github.com/consensys/gnark/test" - "github.com/stretchr/testify/require" "hash" "math/rand" "strconv" "testing" "time" + "github.com/consensys/gnark-crypto/kzg" + "github.com/consensys/gnark/backend/plonk" + bn254r1cs "github.com/consensys/gnark/constraint/bn254" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" + "github.com/consensys/gnark-crypto/ecc" "github.com/consensys/gnark-crypto/ecc/bn254/fr" "github.com/consensys/gnark-crypto/ecc/bn254/fr/gkr" @@ -426,21 +427,21 @@ func testPlonk(t *testing.T, circuit, assignment frontend.Circuit) { } func registerMiMC() { - bn254r1cs.HashBuilderRegistry["mimc"] = bn254MiMC.NewMiMC - stdHash.BuilderRegistry["mimc"] = func(api frontend.API) (stdHash.FieldHasher, error) { + bn254r1cs.RegisterHashBuilder("mimc", bn254MiMC.NewMiMC) + stdHash.Register("mimc", func(api frontend.API) (stdHash.FieldHasher, error) { m, err := mimc.NewMiMC(api) return &m, err - } + }) } func registerConstant(c int) { name := strconv.Itoa(c) - bn254r1cs.HashBuilderRegistry[name] = func() hash.Hash { + bn254r1cs.RegisterHashBuilder(name, func() hash.Hash { return constHashBn254(c) - } - stdHash.BuilderRegistry[name] = func(frontend.API) (stdHash.FieldHasher, error) { + }) + stdHash.Register(name, func(frontend.API) (stdHash.FieldHasher, error) { return constHash(c), nil - } + }) } func init() { diff --git a/std/gkr/compile.go b/std/gkr/compile.go index 7f8dc97925..d219257667 100644 --- a/std/gkr/compile.go +++ b/std/gkr/compile.go @@ -2,14 +2,15 @@ package gkr import ( "fmt" + "math/big" + "math/bits" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/constraint/solver" "github.com/consensys/gnark/frontend" fiatshamir "github.com/consensys/gnark/std/fiat-shamir" "github.com/consensys/gnark/std/hash" "github.com/consensys/gnark/std/utils/algo_utils" - "math/big" - "math/bits" ) type circuitDataForSnark struct { @@ -188,7 +189,7 @@ func (s Solution) Verify(hashName string, initialChallenges ...frontend.Variable } var hsh hash.FieldHasher - if hsh, err = hash.BuilderRegistry[hashName](s.parentApi); err != nil { + if hsh, err = hash.GetFieldHasher(hashName, s.parentApi); err != nil { return err } s.toStore.HashName = hashName diff --git a/std/hash/hash.go b/std/hash/hash.go index 67e0d91ce1..2df30fa1ef 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -18,6 +18,9 @@ limitations under the License. package hash import ( + "errors" + "sync" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -37,7 +40,26 @@ type FieldHasher interface { Reset() } -var BuilderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error)) +var ( + builderRegistry = make(map[string]func(api frontend.API) (FieldHasher, error)) + lock sync.RWMutex +) + +func Register(name string, builder func(api frontend.API) (FieldHasher, error)) { + lock.Lock() + defer lock.Unlock() + builderRegistry[name] = builder +} + +func GetFieldHasher(name string, api frontend.API) (FieldHasher, error) { + lock.RLock() + defer lock.RUnlock() + builder, ok := builderRegistry[name] + if !ok { + return nil, errors.New("hash function not found") + } + return builder(api) +} // BinaryHasher hashes inputs into a short digest. It takes as inputs bytes and // outputs byte array whose length depends on the underlying hash function. For