Skip to content

Commit

Permalink
perf: lookup blueprint compile time improvement (#899)
Browse files Browse the repository at this point in the history
* test: add reference benchmark for many queries on lookup

* perf: cache maxLevel in blueprint

* style: cleanup

* refactor: level_builder to instruction_tree for clarity

* test: restore TestLookup

* feat: remove a debug.Debug check
  • Loading branch information
gbotrel authored Nov 3, 2023
1 parent 24d850c commit 8669a6a
Show file tree
Hide file tree
Showing 18 changed files with 264 additions and 167 deletions.
1 change: 0 additions & 1 deletion constraint/bls12-377/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls12-381/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls24-315/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bls24-317/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 4 additions & 3 deletions constraint/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ type Blueprint interface {
// NbOutputs return the number of output wires this blueprint creates.
NbOutputs(inst Instruction) int

// WireWalker returns a function that walks the wires appearing in the blueprint.
// This is used by the level builder to build a dependency graph between instructions.
WireWalker(inst Instruction) func(cb func(wire uint32))
// UpdateInstructionTree updates the instruction tree;
// since the blue print knows which wires it references, it updates
// the instruction tree with the level of the (new) wires.
UpdateInstructionTree(inst Instruction, tree InstructionTree) Level
}

// Solver represents the state of a constraint system solver at runtime. Blueprint can interact
Expand Down
45 changes: 29 additions & 16 deletions constraint/blueprint_hint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package constraint

import (
"github.com/consensys/gnark/constraint/solver"
"github.com/consensys/gnark/debug"
)

type BlueprintGenericHint struct{}
Expand Down Expand Up @@ -71,24 +72,36 @@ func (b *BlueprintGenericHint) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintGenericHint) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
lenInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < lenInputs; i++ {
n := int(inst.Calldata[j]) // len of linear expr
j++
func (b *BlueprintGenericHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// BlueprintGenericHint knows the input and output to the instruction
maxLevel := LevelUnset

for k := 0; k < n; k++ {
t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
j += 2
// iterate over the inputs and find the max level
lenInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < lenInputs; i++ {
n := int(inst.Calldata[j]) // len of linear expr
j++

for k := 0; k < n; k++ {
wireID := inst.Calldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level > maxLevel {
maxLevel = level
}
if debug.Debug && tree.GetWireLevel(wireID) == LevelUnset {
panic("wire we depend on is not in the instruction tree")
}
}
for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ {
cb(k)
}
}

// iterate over the outputs and insert them at maxLevel + 1
outputLevel := maxLevel + 1
for k := inst.Calldata[j]; k < inst.Calldata[j+1]; k++ {
tree.InsertWire(k, outputLevel)
}
return outputLevel
}
77 changes: 47 additions & 30 deletions constraint/blueprint_logderivlookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ import (
// It is essentially a hint to the solver, but enables storing the table entries only once.
type BlueprintLookupHint struct {
EntriesCalldata []uint32

// stores the maxLevel of the entries computed by WireWalker
maxLevel Level
maxLevelPosition int
maxLevelOffset int
}

// ensures BlueprintLookupHint implements the BlueprintSolvable interface
Expand Down Expand Up @@ -65,47 +70,59 @@ func (b *BlueprintLookupHint) NbOutputs(inst Instruction) int {
return int(inst.Calldata[2])
}

// Wires returns a function that walks the wires appearing in the blueprint.
// This is used by the level builder to build a dependency graph between instructions.
func (b *BlueprintLookupHint) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
// depend on the table UP to the number of entries at time of instruction creation.
nbEntries := int(inst.Calldata[1])
func (b *BlueprintLookupHint) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// depend on the table UP to the number of entries at time of instruction creation.
nbEntries := int(inst.Calldata[1])

// check if we already cached the max level
if b.maxLevelPosition-1 < nbEntries { // adjust for default value of b.maxLevelPosition (0)

// invoke the callback on each wire appearing in the table
j := 0
for i := 0; i < nbEntries; i++ {
j := b.maxLevelOffset // skip the entries we already processed
for i := b.maxLevelPosition; i < nbEntries; i++ {
// first we have the length of the linear expression
n := int(b.EntriesCalldata[j])
j++
for k := 0; k < n; k++ {
t := Term{CID: b.EntriesCalldata[j], VID: b.EntriesCalldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
wireID := b.EntriesCalldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); (level + 1) > b.maxLevel {
b.maxLevel = level + 1
}
}
}
b.maxLevelOffset = j
b.maxLevelPosition = nbEntries
}

// invoke the callback on each wire appearing in the inputs
nbInputs := int(inst.Calldata[2])
j = 3
for i := 0; i < nbInputs; i++ {
// first we have the length of the linear expression
n := int(inst.Calldata[j])
j++
for k := 0; k < n; k++ {
t := Term{CID: inst.Calldata[j], VID: inst.Calldata[j+1]}
if !t.IsConstant() {
cb(t.VID)
}
j += 2
maxLevel := b.maxLevel - 1 // offset for default value.

// update the max level with the lookup query inputs wires
nbInputs := int(inst.Calldata[2])
j := 3
for i := 0; i < nbInputs; i++ {
// first we have the length of the linear expression
n := int(inst.Calldata[j])
j++
for k := 0; k < n; k++ {
wireID := inst.Calldata[j+1]
j += 2
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level > maxLevel {
maxLevel = level
}
}
}

// finally we have the outputs
for i := 0; i < nbInputs; i++ {
cb(uint32(i + int(inst.WireOffset)))
}
// finally we have the outputs
maxLevel++
for i := 0; i < nbInputs; i++ {
tree.InsertWire(uint32(i+int(inst.WireOffset)), maxLevel)
}

return maxLevel
}
44 changes: 30 additions & 14 deletions constraint/blueprint_r1cs.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,39 @@ func (b *BlueprintGenericR1C) DecompressR1C(c *R1C, inst Instruction) {
copySlice(&c.O, lenO, offset+2*(lenL+lenR))
}

func (b *BlueprintGenericR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
lenL := int(inst.Calldata[1])
lenR := int(inst.Calldata[2])
lenO := int(inst.Calldata[3])
func (b *BlueprintGenericR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
// a R1C doesn't know which wires are input and which are outputs
lenL := int(inst.Calldata[1])
lenR := int(inst.Calldata[2])
lenO := int(inst.Calldata[3])

appendWires := func(expectedLen, idx int) {
for k := 0; k < expectedLen; k++ {
idx++
cb(inst.Calldata[idx])
idx++
outputWires := make([]uint32, 0)
maxLevel := LevelUnset
walkWires := func(n, idx int) {
for k := 0; k < n; k++ {
wireID := inst.Calldata[idx+1]
idx += 2 // advance the offset (coeffID + wireID)
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level == LevelUnset {
outputWires = append(outputWires, wireID)
} else if level > maxLevel {
maxLevel = level
}
}
}

const offset = 4
walkWires(lenL, offset)
walkWires(lenR, offset+2*lenL)
walkWires(lenO, offset+2*(lenL+lenR))

const offset = 4
appendWires(lenL, offset)
appendWires(lenR, offset+2*lenL)
appendWires(lenO, offset+2*(lenL+lenR))
// insert the new wires.
maxLevel++
for _, wireID := range outputWires {
tree.InsertWire(wireID, maxLevel)
}

return maxLevel
}
55 changes: 33 additions & 22 deletions constraint/blueprint_scs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,8 @@ func (b *BlueprintGenericSparseR1C) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintGenericSparseR1C) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintGenericSparseR1C) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintGenericSparseR1C) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -172,12 +168,8 @@ func (b *BlueprintSparseR1CMul) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CMul) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintSparseR1CMul) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintSparseR1CMul) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -220,12 +212,8 @@ func (b *BlueprintSparseR1CAdd) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CAdd) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
cb(inst.Calldata[1]) // xb
cb(inst.Calldata[2]) // xc
}
func (b *BlueprintSparseR1CAdd) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:3], tree)
}

func (b *BlueprintSparseR1CAdd) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand Down Expand Up @@ -273,10 +261,8 @@ func (b *BlueprintSparseR1CBool) NbOutputs(inst Instruction) int {
return 0
}

func (b *BlueprintSparseR1CBool) WireWalker(inst Instruction) func(cb func(wire uint32)) {
return func(cb func(wire uint32)) {
cb(inst.Calldata[0]) // xa
}
func (b *BlueprintSparseR1CBool) UpdateInstructionTree(inst Instruction, tree InstructionTree) Level {
return updateInstructionTree(inst.Calldata[0:1], tree)
}

func (b *BlueprintSparseR1CBool) CompressSparseR1C(c *SparseR1C, to *[]uint32) {
Expand All @@ -303,3 +289,28 @@ func (b *BlueprintSparseR1CBool) DecompressSparseR1C(c *SparseR1C, inst Instruct
c.QL = inst.Calldata[1]
c.QM = inst.Calldata[2]
}

func updateInstructionTree(wires []uint32, tree InstructionTree) Level {
// constraint has at most one unsolved wire.
var outputWire uint32
found := false
maxLevel := LevelUnset
for _, wireID := range wires {
if !tree.HasWire(wireID) {
continue
}
if level := tree.GetWireLevel(wireID); level == LevelUnset {
outputWire = wireID
found = true
} else if level > maxLevel {
maxLevel = level
}
}

maxLevel++
if found {
tree.InsertWire(outputWire, maxLevel)
}

return maxLevel
}
1 change: 0 additions & 1 deletion constraint/bn254/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bw6-633/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion constraint/bw6-761/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8669a6a

Please sign in to comment.