Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: optimise older opcodes in reverse order #6476

Merged
merged 16 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 89 additions & 63 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,36 @@ use acir::{

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
pub(crate) struct MergeExpressionsOptimizer<F> {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
modified_gates: HashMap<usize, Opcode<F>>,
deleted_gates: BTreeSet<usize>,
}

impl MergeExpressionsOptimizer {
impl<F: AcirField> MergeExpressionsOptimizer<F> {
pub(crate) fn new() -> Self {
MergeExpressionsOptimizer { resolved_blocks: HashMap::new() }
MergeExpressionsOptimizer {
resolved_blocks: HashMap::new(),
modified_gates: HashMap::new(),
deleted_gates: BTreeSet::new(),
}
}
/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable<F: AcirField>(
pub(crate) fn eliminate_intermediate_variable(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {
// Initialization
self.modified_gates.clear();
self.deleted_gates.clear();
self.resolved_blocks.clear();

// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
self.resolved_blocks = HashMap::new();
let mut used_witness: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
for (i, opcode) in circuit.opcodes.iter().enumerate() {
let witnesses = self.witness_inputs(opcode);
Expand All @@ -46,80 +56,89 @@ impl MergeExpressionsOptimizer {
}
}

let mut modified_gates: HashMap<usize, Opcode<F>> = HashMap::new();
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
// For each opcode, try to get a target opcode to merge with
for (i, (opcode, opcode_position)) in
circuit.opcodes.iter().zip(acir_opcode_positions).enumerate()
{
for (i, opcode) in circuit.opcodes.iter().enumerate() {
if !matches!(opcode, Opcode::AssertZero(_)) {
new_circuit.push(opcode.clone());
new_acir_opcode_positions.push(opcode_position);
continue;
}
let opcode = modified_gates.get(&i).unwrap_or(opcode).clone();
let mut to_keep = true;
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
if let Some(opcode) = self.get_opcode(i, circuit) {
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};

let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]);
if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) =
(&opcode, second_gate)
{
// We cannot merge an expression into an earlier opcode, because this
// would break the 'execution ordering' of the opcodes
// This case can happen because a previous merge would change an opcode
// and eliminate a witness from it, giving new opportunities for this
// witness to be used in only two expressions
// TODO: the missed optimization for the i>b case can be handled by
// - doing this pass again until there is no change, or
// - merging 'b' into 'i' instead
if i < b {
if let Some(expr) = Self::merge(expr_use, expr_define, w) {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(expr_define) {
if !circuit_inputs.contains(&w2) {
let v = used_witness.entry(w2).or_default();
v.insert(b);
v.remove(&i);
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
};
// Merge the opcode with smaller index into the other one
// by updating modified_gates/deleted_gates/used_witness
// returns false if it could not merge them
let mut merge_opcodes = |op1, op2| -> bool {
if op1 == op2 {
return false;
}
let (source, target) = if op1 < op2 { (op1, op2) } else { (op2, op1) };
let source_opcode = self.get_opcode(source, circuit);
let target_opcode = self.get_opcode(target, circuit);
if let (
Some(Opcode::AssertZero(expr_use)),
Some(Opcode::AssertZero(expr_define)),
) = (target_opcode, source_opcode)
{
if let Some(expr) =
Self::merge_expression(&expr_use, &expr_define, w)
{
self.modified_gates.insert(target, Opcode::AssertZero(expr));
self.deleted_gates.insert(source);
// Update the 'used_witness' map to account for the merge.
let mut witness_list = CircuitSimulator::expr_wit(&expr_use);
aakoshh marked this conversation as resolved.
Show resolved Hide resolved
witness_list.extend(CircuitSimulator::expr_wit(&expr_define));
for w2 in witness_list {
if !circuit_inputs.contains(&w2) {
used_witness.entry(w2).and_modify(|v| {
v.insert(target);
v.remove(&source);
});
}
}
return true;
}
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
false
};

if merge_opcodes(b, i) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
}
}
}
}

// Construct the new circuit from modified/deleted gates
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();

if to_keep {
let opcode = modified_gates.get(&i).cloned().unwrap_or(opcode);
new_circuit.push(opcode);
new_acir_opcode_positions.push(opcode_position);
for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
if let Some(op) = self.get_opcode(i, circuit) {
new_circuit.push(op);
new_acir_opcode_positions.push(*opcode_position);
}
}
(new_circuit, new_acir_opcode_positions)
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
fn brillig_input_wit(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
aakoshh marked this conversation as resolved.
Show resolved Hide resolved
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
Expand Down Expand Up @@ -152,7 +171,7 @@ impl MergeExpressionsOptimizer {
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
aakoshh marked this conversation as resolved.
Show resolved Hide resolved
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => {
Expand Down Expand Up @@ -198,7 +217,7 @@ impl MergeExpressionsOptimizer {

// Merge 'expr' into 'target' via Gaussian elimination on 'w'
// Returns None if the expressions cannot be merged
fn merge<F: AcirField>(
fn merge_expression(
target: &Expression<F>,
expr: &Expression<F>,
w: Witness,
Expand Down Expand Up @@ -226,6 +245,13 @@ impl MergeExpressionsOptimizer {
}
None
}

fn get_opcode(&self, g: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
if self.deleted_gates.contains(&g) {
return None;
}
self.modified_gates.get(&g).or(circuit.opcodes.get(g)).cloned()
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn main(x: Field) {
value += term2;
value.assert_max_bit_size::<1>();

// Regression test for Aztec Packages issue #6451
// Regression test for #6447 (Aztec Packages issue #9488)
let y = unsafe { empty(x + 1) };
let z = y + x + 1;
let z1 = z + y;
Expand Down
Loading