From 8669a6a139112edd9a17aed87930ea7519c414be Mon Sep 17 00:00:00 2001 From: Gautam Botrel Date: Fri, 3 Nov 2023 10:24:27 -0500 Subject: [PATCH] perf: lookup blueprint compile time improvement (#899) * test: add reference benchmark for many queries on lookup * perf: cache maxLevel in blueprint * style: cleanup * refactor: level_builder to instruction_tree for clarity * test: restore TestLookup * feat: remove a debug.Debug check --- constraint/bls12-377/r1cs_test.go | 1 - constraint/bls12-381/r1cs_test.go | 1 - constraint/bls24-315/r1cs_test.go | 1 - constraint/bls24-317/r1cs_test.go | 1 - constraint/blueprint.go | 7 +- constraint/blueprint_hint.go | 45 +++++++---- constraint/blueprint_logderivlookup.go | 77 +++++++++++-------- constraint/blueprint_r1cs.go | 44 +++++++---- constraint/blueprint_scs.go | 55 +++++++------ constraint/bn254/r1cs_test.go | 1 - constraint/bw6-633/r1cs_test.go | 1 - constraint/bw6-761/r1cs_test.go | 1 - constraint/core.go | 21 +++-- constraint/instruction_tree.go | 64 +++++++++++++++ constraint/level_builder.go | 55 ------------- constraint/tinyfield/r1cs_test.go | 1 - .../representations/tests/r1cs.go.tmpl | 1 - .../logderivlookup/logderivlookup_test.go | 54 +++++++++---- 18 files changed, 264 insertions(+), 167 deletions(-) create mode 100644 constraint/instruction_tree.go delete mode 100644 constraint/level_builder.go diff --git a/constraint/bls12-377/r1cs_test.go b/constraint/bls12-377/r1cs_test.go index 044c55d31a..9883601a4a 100644 --- a/constraint/bls12-377/r1cs_test.go +++ b/constraint/bls12-377/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/bls12-381/r1cs_test.go b/constraint/bls12-381/r1cs_test.go index 28f77b956b..5a1e5a7c31 100644 --- a/constraint/bls12-381/r1cs_test.go +++ b/constraint/bls12-381/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/bls24-315/r1cs_test.go b/constraint/bls24-315/r1cs_test.go index 4c42f78ee5..6a48dbc3f5 100644 --- a/constraint/bls24-315/r1cs_test.go +++ b/constraint/bls24-315/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/bls24-317/r1cs_test.go b/constraint/bls24-317/r1cs_test.go index 40bb573a0b..6c4bc3deb0 100644 --- a/constraint/bls24-317/r1cs_test.go +++ b/constraint/bls24-317/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/blueprint.go b/constraint/blueprint.go index 2949f0665f..34516326e7 100644 --- a/constraint/blueprint.go +++ b/constraint/blueprint.go @@ -18,9 +18,10 @@ type Blueprint interface { // NbOutputs return the number of output wires this blueprint creates. NbOutputs(inst Instruction) int - // WireWalker returns a function that walks the wires appearing in the blueprint. - // This is used by the level builder to build a dependency graph between instructions. - WireWalker(inst Instruction) func(cb func(wire uint32)) + // UpdateInstructionTree updates the instruction tree; + // since the blue print knows which wires it references, it updates + // the instruction tree with the level of the (new) wires. + UpdateInstructionTree(inst Instruction, tree InstructionTree) Level } // Solver represents the state of a constraint system solver at runtime. Blueprint can interact diff --git a/constraint/blueprint_hint.go b/constraint/blueprint_hint.go index ac96413ef0..1144a5e101 100644 --- a/constraint/blueprint_hint.go +++ b/constraint/blueprint_hint.go @@ -2,6 +2,7 @@ package constraint import ( "github.com/consensys/gnark/constraint/solver" + "github.com/consensys/gnark/debug" ) type BlueprintGenericHint struct{} @@ -71,24 +72,36 @@ func (b *BlueprintGenericHint) NbOutputs(inst Instruction) int { return 0 } -func (b *BlueprintGenericHint) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - lenInputs := int(inst.Calldata[2]) - j := 3 - for i := 0; i < lenInputs; i++ { - n := int(inst.Calldata[j]) // len of linear expr - j++ +func (b *BlueprintGenericHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + // BlueprintGenericHint knows the input and output to the instruction + maxLevel := LevelUnset - for k := 0; k < n; k++ { - t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]} - if !t.IsConstant() { - cb(t.VID) - } - j += 2 + // iterate over the inputs and find the max level + lenInputs := int(inst.Calldata[2]) + j := 3 + for i := 0; i < lenInputs; i++ { + n := int(inst.Calldata[j]) // len of linear expr + j++ + + for k := 0; k < n; k++ { + wireID := inst.Calldata[j+1] + j += 2 + if !tree.HasWire(wireID) { + continue + } + if level := tree.GetWireLevel(wireID); level > maxLevel { + maxLevel = level + } + if debug.Debug && tree.GetWireLevel(wireID) == LevelUnset { + panic("wire we depend on is not in the instruction tree") } } - for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ { - cb(k) - } } + + // iterate over the outputs and insert them at maxLevel + 1 + outputLevel := maxLevel + 1 + for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ { + tree.InsertWire(k, outputLevel) + } + return outputLevel } diff --git a/constraint/blueprint_logderivlookup.go b/constraint/blueprint_logderivlookup.go index b46f6d8f01..40fbe19ef3 100644 --- a/constraint/blueprint_logderivlookup.go +++ b/constraint/blueprint_logderivlookup.go @@ -11,6 +11,11 @@ import ( // It is essentially a hint to the solver, but enables storing the table entries only once. type BlueprintLookupHint struct { EntriesCalldata []uint32 + + // stores the maxLevel of the entries computed by WireWalker + maxLevel Level + maxLevelPosition int + maxLevelOffset int } // ensures BlueprintLookupHint implements the BlueprintSolvable interface @@ -65,47 +70,59 @@ func (b *BlueprintLookupHint) NbOutputs(inst Instruction) int { return int(inst.Calldata[2]) } -// Wires returns a function that walks the wires appearing in the blueprint. -// This is used by the level builder to build a dependency graph between instructions. -func (b *BlueprintLookupHint) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - // depend on the table UP to the number of entries at time of instruction creation. - nbEntries := int(inst.Calldata[1]) +func (b *BlueprintLookupHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + // depend on the table UP to the number of entries at time of instruction creation. + nbEntries := int(inst.Calldata[1]) + + // check if we already cached the max level + if b.maxLevelPosition-1 < nbEntries { // adjust for default value of b.maxLevelPosition (0) - // invoke the callback on each wire appearing in the table - j := 0 - for i := 0; i < nbEntries; i++ { + j := b.maxLevelOffset // skip the entries we already processed + for i := b.maxLevelPosition; i < nbEntries; i++ { // first we have the length of the linear expression n := int(b.EntriesCalldata[j]) j++ for k := 0; k < n; k++ { - t := Term{CID: b.EntriesCalldata[j], VID: b.EntriesCalldata[j+1]} - if !t.IsConstant() { - cb(t.VID) - } + wireID := b.EntriesCalldata[j+1] j += 2 + if !tree.HasWire(wireID) { + continue + } + if level := tree.GetWireLevel(wireID); (level + 1) > b.maxLevel { + b.maxLevel = level + 1 + } } } + b.maxLevelOffset = j + b.maxLevelPosition = nbEntries + } - // invoke the callback on each wire appearing in the inputs - nbInputs := int(inst.Calldata[2]) - j = 3 - for i := 0; i < nbInputs; i++ { - // first we have the length of the linear expression - n := int(inst.Calldata[j]) - j++ - for k := 0; k < n; k++ { - t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]} - if !t.IsConstant() { - cb(t.VID) - } - j += 2 + maxLevel := b.maxLevel - 1 // offset for default value. + + // update the max level with the lookup query inputs wires + nbInputs := int(inst.Calldata[2]) + j := 3 + for i := 0; i < nbInputs; i++ { + // first we have the length of the linear expression + n := int(inst.Calldata[j]) + j++ + for k := 0; k < n; k++ { + wireID := inst.Calldata[j+1] + j += 2 + if !tree.HasWire(wireID) { + continue + } + if level := tree.GetWireLevel(wireID); level > maxLevel { + maxLevel = level } } + } - // finally we have the outputs - for i := 0; i < nbInputs; i++ { - cb(uint32(i + int(inst.WireOffset))) - } + // finally we have the outputs + maxLevel++ + for i := 0; i < nbInputs; i++ { + tree.InsertWire(uint32(i+int(inst.WireOffset)), maxLevel) } + + return maxLevel } diff --git a/constraint/blueprint_r1cs.go b/constraint/blueprint_r1cs.go index b231eda067..e081c1c8b7 100644 --- a/constraint/blueprint_r1cs.go +++ b/constraint/blueprint_r1cs.go @@ -58,23 +58,39 @@ func (b *BlueprintGenericR1C) DecompressR1C(c *R1C, inst Instruction) { copySlice(&c.O, lenO, offset+2*(lenL+lenR)) } -func (b *BlueprintGenericR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - lenL := int(inst.Calldata[1]) - lenR := int(inst.Calldata[2]) - lenO := int(inst.Calldata[3]) +func (b *BlueprintGenericR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + // a R1C doesn't know which wires are input and which are outputs + lenL := int(inst.Calldata[1]) + lenR := int(inst.Calldata[2]) + lenO := int(inst.Calldata[3]) - appendWires := func(expectedLen, idx int) { - for k := 0; k < expectedLen; k++ { - idx++ - cb(inst.Calldata[idx]) - idx++ + outputWires := make([]uint32, 0) + maxLevel := LevelUnset + walkWires := func(n, idx int) { + for k := 0; k < n; k++ { + wireID := inst.Calldata[idx+1] + idx += 2 // advance the offset (coeffID + wireID) + if !tree.HasWire(wireID) { + continue + } + if level := tree.GetWireLevel(wireID); level == LevelUnset { + outputWires = append(outputWires, wireID) + } else if level > maxLevel { + maxLevel = level } } + } + + const offset = 4 + walkWires(lenL, offset) + walkWires(lenR, offset+2*lenL) + walkWires(lenO, offset+2*(lenL+lenR)) - const offset = 4 - appendWires(lenL, offset) - appendWires(lenR, offset+2*lenL) - appendWires(lenO, offset+2*(lenL+lenR)) + // insert the new wires. + maxLevel++ + for _, wireID := range outputWires { + tree.InsertWire(wireID, maxLevel) } + + return maxLevel } diff --git a/constraint/blueprint_scs.go b/constraint/blueprint_scs.go index 3f0973e234..69dad20bc5 100644 --- a/constraint/blueprint_scs.go +++ b/constraint/blueprint_scs.go @@ -28,12 +28,8 @@ func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int { return 0 } -func (b *BlueprintGenericSparseR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - cb(inst.Calldata[0]) // xa - cb(inst.Calldata[1]) // xb - cb(inst.Calldata[2]) // xc - } +func (b *BlueprintGenericSparseR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + return updateInstructionTree(inst.Calldata[0:3], tree) } func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) { @@ -172,12 +168,8 @@ func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int { return 0 } -func (b *BlueprintSparseR1CMul) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - cb(inst.Calldata[0]) // xa - cb(inst.Calldata[1]) // xb - cb(inst.Calldata[2]) // xc - } +func (b *BlueprintSparseR1CMul) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + return updateInstructionTree(inst.Calldata[0:3], tree) } func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) { @@ -220,12 +212,8 @@ func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int { return 0 } -func (b *BlueprintSparseR1CAdd) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - cb(inst.Calldata[0]) // xa - cb(inst.Calldata[1]) // xb - cb(inst.Calldata[2]) // xc - } +func (b *BlueprintSparseR1CAdd) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + return updateInstructionTree(inst.Calldata[0:3], tree) } func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) { @@ -273,10 +261,8 @@ func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int { return 0 } -func (b *BlueprintSparseR1CBool) WireWalker(inst Instruction) func(cb func(wire uint32)) { - return func(cb func(wire uint32)) { - cb(inst.Calldata[0]) // xa - } +func (b *BlueprintSparseR1CBool) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level { + return updateInstructionTree(inst.Calldata[0:1], tree) } func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) { @@ -303,3 +289,28 @@ func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruct c.QL = inst.Calldata[1] c.QM = inst.Calldata[2] } + +func updateInstructionTree(wires []uint32, tree InstructionTree) Level { + // constraint has at most one unsolved wire. + var outputWire uint32 + found := false + maxLevel := LevelUnset + for _, wireID := range wires { + if !tree.HasWire(wireID) { + continue + } + if level := tree.GetWireLevel(wireID); level == LevelUnset { + outputWire = wireID + found = true + } else if level > maxLevel { + maxLevel = level + } + } + + maxLevel++ + if found { + tree.InsertWire(outputWire, maxLevel) + } + + return maxLevel +} diff --git a/constraint/bn254/r1cs_test.go b/constraint/bn254/r1cs_test.go index d295603ebf..98b1786581 100644 --- a/constraint/bn254/r1cs_test.go +++ b/constraint/bn254/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/bw6-633/r1cs_test.go b/constraint/bw6-633/r1cs_test.go index 7111c25c77..1fc0ce531a 100644 --- a/constraint/bw6-633/r1cs_test.go +++ b/constraint/bw6-633/r1cs_test.go @@ -82,7 +82,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/bw6-761/r1cs_test.go b/constraint/bw6-761/r1cs_test.go index d8f3b48039..45e3e0b0ad 100644 --- a/constraint/bw6-761/r1cs_test.go +++ b/constraint/bw6-761/r1cs_test.go @@ -85,7 +85,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/constraint/core.go b/constraint/core.go index cd23e468d9..c37fc4fd10 100644 --- a/constraint/core.go +++ b/constraint/core.go @@ -121,8 +121,7 @@ type System struct { bitLen int `cbor:"-"` // level builder - lbWireLevel []int `cbor:"-"` // at which level we solve a wire. init at -1. - lbOutputs []uint32 `cbor:"-"` // wire outputs for current constraint. + lbWireLevel []Level `cbor:"-"` // at which level we solve a wire. init at -1. CommitmentInfo Commitments GkrInfo GkrInfo @@ -143,8 +142,7 @@ func NewSystem(scalarField *big.Int, capacity int, t SystemType) System { bitLen: scalarField.BitLen(), Instructions: make([]PackedInstruction, 0, capacity), CallData: make([]uint32, 0, capacity*8), - lbOutputs: make([]uint32, 0, 256), - lbWireLevel: make([]int, 0, capacity), + lbWireLevel: make([]Level, 0, capacity), Levels: make([][]int, 0, capacity/2), CommitmentInfo: NewCommitments(t), } @@ -229,6 +227,11 @@ func (system *System) FieldBitLen() int { func (system *System) AddInternalVariable() (idx int) { idx = system.NbInternalVariables + system.GetNbPublicVariables() + system.GetNbSecretVariables() system.NbInternalVariables++ + // also grow the level slice + system.lbWireLevel = append(system.lbWireLevel, LevelUnset) + if debug.Debug && len(system.lbWireLevel) != system.NbInternalVariables { + panic("internal error") + } return idx } @@ -405,7 +408,15 @@ func (cs *System) AddInstruction(bID BlueprintID, calldata []uint32) []uint32 { cs.Instructions = append(cs.Instructions, pi) // update the instruction dependency tree - cs.updateLevel(len(cs.Instructions)-1, blueprint.WireWalker(inst)) + level := blueprint.UpdateInstructionTree(inst, cs) + iID := len(cs.Instructions) - 1 + + // we can't skip levels, so appending is fine. + if int(level) >= len(cs.Levels) { + cs.Levels = append(cs.Levels, []int{iID}) + } else { + cs.Levels[level] = append(cs.Levels[level], iID) + } return wires } diff --git a/constraint/instruction_tree.go b/constraint/instruction_tree.go new file mode 100644 index 0000000000..fe816df8e6 --- /dev/null +++ b/constraint/instruction_tree.go @@ -0,0 +1,64 @@ +package constraint + +import ( + "github.com/consensys/gnark/debug" +) + +type Level int + +const ( + LevelUnset Level = -1 +) + +type InstructionTree interface { + // InsertWire inserts a wire in the instruction tree at the given level. + // If the wire is already in the instruction tree, it panics. + InsertWire(wire uint32, level Level) + + // HasWire returns true if the wire is in the instruction tree. + // False if it's a constant or an input. + HasWire(wire uint32) bool + + // GetWireLevel returns the level of the wire in the instruction tree. + // If HasWire(wire) returns false, behavior is undefined. + GetWireLevel(wire uint32) Level +} + +// the instruction tree is a simple array of levels. +// it's morally a map[uint32 (wireID)]Level, but we use an array for performance reasons. + +func (system *System) HasWire(wireID uint32) bool { + offset := system.internalWireOffset() + if wireID < offset { + // it's a input. + return false + } + // if wireID == maxUint32, it's a constant. + return (wireID - offset) < uint32(len(system.lbWireLevel)) +} + +func (system *System) GetWireLevel(wireID uint32) Level { + return system.lbWireLevel[wireID-system.internalWireOffset()] +} + +func (system *System) InsertWire(wireID uint32, level Level) { + if debug.Debug { + if level < 0 { + panic("level must be >= 0") + } + if wireID < system.internalWireOffset() { + panic("cannot insert input wire in instruction tree") + } + } + wireID -= system.internalWireOffset() + if system.lbWireLevel[wireID] != LevelUnset { + panic("wire already exist in instruction tree") + } + + system.lbWireLevel[wireID] = level +} + +// internalWireOffset returns the position of the first internal wire in the wireIDs. +func (system *System) internalWireOffset() uint32 { + return uint32(system.GetNbPublicVariables() + system.GetNbSecretVariables()) +} diff --git a/constraint/level_builder.go b/constraint/level_builder.go deleted file mode 100644 index e6cc47cbe6..0000000000 --- a/constraint/level_builder.go +++ /dev/null @@ -1,55 +0,0 @@ -package constraint - -// The main idea here is to find a naive clustering of independent constraints that can be solved in parallel. -// -// We know that at each constraint, we will have at most one unsolved wire. -// (a constraint may have no unsolved wire in which case it is a plain check that the constraint hold, -// or it may additionally have some wires that will be solved by solver hints) -// -// We build a graph of dependency; we say that a wire is solved at a level l -// --> l = max(level_of_dependencies(wire)) + 1 -func (system *System) updateLevel(iID int, walkWires func(cb func(wire uint32))) { - level := -1 - - // process all wires of the instruction - walkWires(func(wire uint32) { - system.processWire(wire, &level) - }) - - // level = max(dependencies) + 1 - level++ - - // mark output wire with level - for _, wireID := range system.lbOutputs { - system.lbWireLevel[wireID] = level - } - - // we can't skip levels, so appending is fine. - if level >= len(system.Levels) { - system.Levels = append(system.Levels, []int{iID}) - } else { - system.Levels[level] = append(system.Levels[level], iID) - } - // clean the table. NB! Do not remove or move, this is required to make the - // compilation deterministic. - system.lbOutputs = system.lbOutputs[:0] -} - -func (system *System) processWire(wireID uint32, maxLevel *int) { - if wireID < uint32(system.GetNbPublicVariables()+system.GetNbSecretVariables()) { - return // ignore inputs - } - for int(wireID) >= len(system.lbWireLevel) { - // we didn't encounter this wire yet, we need to grow b.wireLevels - system.lbWireLevel = append(system.lbWireLevel, -1) - } - if system.lbWireLevel[wireID] != -1 { - // we know how to solve this wire, it's a dependency - if system.lbWireLevel[wireID] > *maxLevel { - *maxLevel = system.lbWireLevel[wireID] - } - return - } - // this wire is an output to the instruction - system.lbOutputs = append(system.lbOutputs, wireID) -} diff --git a/constraint/tinyfield/r1cs_test.go b/constraint/tinyfield/r1cs_test.go index 5015c06972..425f6e3f2b 100644 --- a/constraint/tinyfield/r1cs_test.go +++ b/constraint/tinyfield/r1cs_test.go @@ -85,7 +85,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl index 81eebde6bc..af4480ee01 100644 --- a/internal/generator/backend/template/representations/tests/r1cs.go.tmpl +++ b/internal/generator/backend/template/representations/tests/r1cs.go.tmpl @@ -73,7 +73,6 @@ func TestSerialization(t *testing.T) { "System.lbWireLevel", "System.genericHint", "System.SymbolTable", - "System.lbOutputs", "System.bitLen")); diff != "" { t.Fatalf("round trip mismatch (-want +got):\n%s", diff) } diff --git a/std/lookup/logderivlookup/logderivlookup_test.go b/std/lookup/logderivlookup/logderivlookup_test.go index 0070f18352..a3cdf039d7 100644 --- a/std/lookup/logderivlookup/logderivlookup_test.go +++ b/std/lookup/logderivlookup/logderivlookup_test.go @@ -7,8 +7,8 @@ import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test" ) @@ -47,19 +47,47 @@ func TestLookup(t *testing.T) { witness.Expected[i] = new(big.Int).Set(witness.Entries[q.Int64()].(*big.Int)) } - ccs, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &LookupCircuit{}) - assert.NoError(err) - - w, err := frontend.NewWitness(&witness, ecc.BN254.ScalarField()) - assert.NoError(err) + assert.CheckCircuit(&LookupCircuit{}, test.WithValidAssignment(&witness)) +} - _, err = ccs.Solve(w) - assert.NoError(err) +type LookupCircuitLarge struct { + Entries [32000 * 2]frontend.Variable + Queries, Expected [32000 * 2]frontend.Variable +} - err = test.IsSolved(&LookupCircuit{}, &witness, ecc.BN254.ScalarField()) - assert.NoError(err) +func (c *LookupCircuitLarge) Define(api frontend.API) error { + t := New(api) + for i := range c.Entries { + t.Insert(c.Entries[i]) + } + results := make([]frontend.Variable, len(c.Queries)) + for i := range c.Queries { + results[i] = t.Lookup(c.Queries[i])[0] + } + if len(results) != len(c.Expected) { + return fmt.Errorf("length mismatch") + } + for i := range results { + api.AssertIsEqual(results[i], c.Expected[i]) + } + return nil +} - assert.ProverSucceeded(&LookupCircuit{}, &witness, - test.WithCurves(ecc.BN254), - test.WithBackends(backend.GROTH16, backend.PLONK)) +func BenchmarkCompileManyLookup(b *testing.B) { + b.Run("scs", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := frontend.Compile(ecc.BN254.ScalarField(), scs.NewBuilder, &LookupCircuitLarge{}) + if err != nil { + b.Fatal(err) + } + } + }) + b.Run("r1cs", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, err := frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &LookupCircuitLarge{}) + if err != nil { + b.Fatal(err) + } + } + }) }