From 82b38bd33016776fc3a84b692adee033a4f2023d Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Thu, 5 Dec 2024 17:17:59 +0000 Subject: [PATCH] fix: Make `nargo::ops::transform_program` idempotent (#6695) --- Cargo.lock | 1 + acvm-repo/acvm/Cargo.toml | 2 +- acvm-repo/acvm/src/compiler/mod.rs | 41 ++- acvm-repo/acvm/src/compiler/optimizers/mod.rs | 2 +- .../acvm/src/compiler/transformers/mod.rs | 332 +++++++++++++++++- tooling/nargo_cli/src/cli/compile_cmd.rs | 6 + 6 files changed, 366 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2f19ed704b2..f414a126495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,7 @@ dependencies = [ "ark-bn254", "bn254_blackbox_solver", "brillig_vm", + "fxhash", "indexmap 1.9.3", "num-bigint", "proptest", diff --git a/acvm-repo/acvm/Cargo.toml b/acvm-repo/acvm/Cargo.toml index e513ae4e727..ba01ac8ec16 100644 --- a/acvm-repo/acvm/Cargo.toml +++ b/acvm-repo/acvm/Cargo.toml @@ -17,7 +17,7 @@ workspace = true thiserror.workspace = true tracing.workspace = true serde.workspace = true - +fxhash.workspace = true acir.workspace = true brillig_vm.workspace = true acvm_blackbox_solver.workspace = true diff --git a/acvm-repo/acvm/src/compiler/mod.rs b/acvm-repo/acvm/src/compiler/mod.rs index 8829f77e50b..e32c0665c0f 100644 --- a/acvm-repo/acvm/src/compiler/mod.rs +++ b/acvm-repo/acvm/src/compiler/mod.rs @@ -16,6 +16,10 @@ pub use simulator::CircuitSimulator; use transformers::transform_internal; pub use transformers::{transform, MIN_EXPRESSION_WIDTH}; +/// We need multiple passes to stabilize the output. +/// The value was determined by running tests. +const MAX_OPTIMIZER_PASSES: usize = 3; + /// This module moves and decomposes acir opcodes. The transformation map allows consumers of this module to map /// metadata they had about the opcodes to the new opcode structure generated after the transformation. #[derive(Debug)] @@ -28,9 +32,9 @@ impl AcirTransformationMap { /// Builds a map from a vector of pointers to the old acir opcodes. /// The index of the vector is the new opcode index. /// The value of the vector is the old opcode index pointed. - fn new(acir_opcode_positions: Vec) -> Self { + fn new(acir_opcode_positions: &[usize]) -> Self { let mut old_indices_to_new_indices = HashMap::with_capacity(acir_opcode_positions.len()); - for (new_index, old_index) in acir_opcode_positions.into_iter().enumerate() { + for (new_index, old_index) in acir_opcode_positions.iter().copied().enumerate() { old_indices_to_new_indices.entry(old_index).or_insert_with(Vec::new).push(new_index); } AcirTransformationMap { old_indices_to_new_indices } @@ -72,17 +76,42 @@ fn transform_assert_messages( } /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. +/// +/// Runs multiple passes until the output stabilizes. pub fn compile( acir: Circuit, expression_width: ExpressionWidth, ) -> (Circuit, AcirTransformationMap) { - let (acir, acir_opcode_positions) = optimize_internal(acir); + let mut pass = 0; + let mut prev_opcodes_hash = fxhash::hash64(&acir.opcodes); + let mut prev_acir = acir; + + // For most test programs it would be enough to only loop `transform_internal`, + // but some of them don't stabilize unless we also repeat the backend agnostic optimizations. + let (mut acir, acir_opcode_positions) = loop { + let (acir, acir_opcode_positions) = optimize_internal(prev_acir); + + // Stop if we have already done at least one transform and an extra optimization changed nothing. + if pass > 0 && prev_opcodes_hash == fxhash::hash64(&acir.opcodes) { + break (acir, acir_opcode_positions); + } - let (mut acir, acir_opcode_positions) = - transform_internal(acir, expression_width, acir_opcode_positions); + let (acir, acir_opcode_positions) = + transform_internal(acir, expression_width, acir_opcode_positions); + + let opcodes_hash = fxhash::hash64(&acir.opcodes); + + // Stop if the output hasn't change in this loop or we went too long. + if pass == MAX_OPTIMIZER_PASSES - 1 || prev_opcodes_hash == opcodes_hash { + break (acir, acir_opcode_positions); + } - let transformation_map = AcirTransformationMap::new(acir_opcode_positions); + pass += 1; + prev_acir = acir; + prev_opcodes_hash = opcodes_hash; + }; + let transformation_map = AcirTransformationMap::new(&acir_opcode_positions); acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); (acir, transformation_map) diff --git a/acvm-repo/acvm/src/compiler/optimizers/mod.rs b/acvm-repo/acvm/src/compiler/optimizers/mod.rs index 1947a80dc35..f86bf500998 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/mod.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/mod.rs @@ -23,7 +23,7 @@ use super::{transform_assert_messages, AcirTransformationMap}; pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { let (mut acir, new_opcode_positions) = optimize_internal(acir); - let transformation_map = AcirTransformationMap::new(new_opcode_positions); + let transformation_map = AcirTransformationMap::new(&new_opcode_positions); acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index c9ce4ac7895..e7e8f885710 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -1,5 +1,10 @@ use acir::{ - circuit::{brillig::BrilligOutputs, Circuit, ExpressionWidth, Opcode}, + circuit::{ + self, + brillig::{BrilligInputs, BrilligOutputs}, + opcodes::{BlackBoxFuncCall, FunctionInput, MemOp}, + Circuit, ExpressionWidth, Opcode, + }, native_types::{Expression, Witness}, AcirField, }; @@ -12,6 +17,7 @@ pub use csat::MIN_EXPRESSION_WIDTH; use super::{ optimizers::MergeExpressionsOptimizer, transform_assert_messages, AcirTransformationMap, + MAX_OPTIMIZER_PASSES, }; /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. @@ -26,7 +32,7 @@ pub fn transform( let (mut acir, acir_opcode_positions) = transform_internal(acir, expression_width, acir_opcode_positions); - let transformation_map = AcirTransformationMap::new(acir_opcode_positions); + let transformation_map = AcirTransformationMap::new(&acir_opcode_positions); acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); @@ -36,9 +42,52 @@ pub fn transform( /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. /// /// Accepts an injected `acir_opcode_positions` to allow transformations to be applied directly after optimizations. +/// +/// Does multiple passes until the output stabilizes. #[tracing::instrument(level = "trace", name = "transform_acir", skip(acir, acir_opcode_positions))] pub(super) fn transform_internal( - acir: Circuit, + mut acir: Circuit, + expression_width: ExpressionWidth, + mut acir_opcode_positions: Vec, +) -> (Circuit, Vec) { + // Allow multiple passes until we have stable output. + let mut prev_opcodes_hash = fxhash::hash64(&acir.opcodes); + + // For most test programs it would be enough to loop here, but some of them + // don't stabilize unless we also repeat the backend agnostic optimizations. + for _ in 0..MAX_OPTIMIZER_PASSES { + let (new_acir, new_acir_opcode_positions) = + transform_internal_once(acir, expression_width, acir_opcode_positions); + + acir = new_acir; + acir_opcode_positions = new_acir_opcode_positions; + + let new_opcodes_hash = fxhash::hash64(&acir.opcodes); + + if new_opcodes_hash == prev_opcodes_hash { + break; + } + prev_opcodes_hash = new_opcodes_hash; + } + // After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be, + // which would cause gaps if we ran the optimization a second time, making it look like new variables were added. + acir.current_witness_index = max_witness(&acir).witness_index(); + + (acir, acir_opcode_positions) +} + +/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. +/// +/// Accepts an injected `acir_opcode_positions` to allow transformations to be applied directly after optimizations. +/// +/// Does a single optimization pass. +#[tracing::instrument( + level = "trace", + name = "transform_acir_once", + skip(acir, acir_opcode_positions) +)] +fn transform_internal_once( + mut acir: Circuit, expression_width: ExpressionWidth, acir_opcode_positions: Vec, ) -> (Circuit, Vec) { @@ -79,8 +128,6 @@ pub(super) fn transform_internal( &mut next_witness_index, ); - // Update next_witness counter - next_witness_index += (intermediate_variables.len() - len) as u32; let mut new_opcodes = Vec::new(); for (g, (norm, w)) in intermediate_variables.iter().skip(len) { // de-normalize @@ -150,23 +197,288 @@ pub(super) fn transform_internal( let current_witness_index = next_witness_index - 1; - let acir = Circuit { + acir = Circuit { current_witness_index, expression_width, opcodes: transformed_opcodes, // The transformer does not add new public inputs ..acir }; + let mut merge_optimizer = MergeExpressionsOptimizer::new(); + let (opcodes, new_acir_opcode_positions) = merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions); - // n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. - let acir = Circuit { - current_witness_index, - expression_width, + + // n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less. + acir = Circuit { opcodes, // The optimizer does not add new public inputs ..acir }; + (acir, new_acir_opcode_positions) } + +/// Find the witness with the highest ID in the circuit. +fn max_witness(circuit: &Circuit) -> Witness { + let mut witnesses = WitnessFolder::new(Witness::default(), |state, witness| { + *state = witness.max(*state); + }); + witnesses.fold_circuit(circuit); + witnesses.into_state() +} + +/// Fold all witnesses in a circuit. +struct WitnessFolder { + state: S, + accumulate: A, +} + +impl WitnessFolder +where + A: Fn(&mut S, Witness), +{ + /// Create the folder with some initial state and an accumulator function. + fn new(init: S, accumulate: A) -> Self { + Self { state: init, accumulate } + } + + /// Take the accumulated state. + fn into_state(self) -> S { + self.state + } + + /// Add all witnesses from the circuit. + fn fold_circuit(&mut self, circuit: &Circuit) { + self.fold_many(circuit.private_parameters.iter()); + self.fold_many(circuit.public_parameters.0.iter()); + self.fold_many(circuit.return_values.0.iter()); + for opcode in &circuit.opcodes { + self.fold_opcode(opcode); + } + } + + /// Fold a witness into the state. + fn fold(&mut self, witness: Witness) { + (self.accumulate)(&mut self.state, witness); + } + + /// Fold many witnesses into the state. + fn fold_many<'w, I: Iterator>(&mut self, witnesses: I) { + for w in witnesses { + self.fold(*w); + } + } + + /// Add witnesses from the opcode. + fn fold_opcode(&mut self, opcode: &Opcode) { + match opcode { + Opcode::AssertZero(expr) => { + self.fold_expr(expr); + } + Opcode::BlackBoxFuncCall(call) => self.fold_blackbox(call), + Opcode::MemoryOp { block_id: _, op, predicate } => { + let MemOp { operation, index, value } = op; + self.fold_expr(operation); + self.fold_expr(index); + self.fold_expr(value); + if let Some(pred) = predicate { + self.fold_expr(pred); + } + } + Opcode::MemoryInit { block_id: _, init, block_type: _ } => { + for w in init { + self.fold(*w); + } + } + // We keep the display for a BrilligCall and circuit Call separate as they + // are distinct in their functionality and we should maintain this separation for debugging. + Opcode::BrilligCall { id: _, inputs, outputs, predicate } => { + if let Some(pred) = predicate { + self.fold_expr(pred); + } + self.fold_brillig_inputs(inputs); + self.fold_brillig_outputs(outputs); + } + Opcode::Call { id: _, inputs, outputs, predicate } => { + if let Some(pred) = predicate { + self.fold_expr(pred); + } + self.fold_many(inputs.iter()); + self.fold_many(outputs.iter()); + } + } + } + + fn fold_expr(&mut self, expr: &Expression) { + for i in &expr.mul_terms { + self.fold(i.1); + self.fold(i.2); + } + for i in &expr.linear_combinations { + self.fold(i.1); + } + } + + fn fold_brillig_inputs(&mut self, inputs: &[BrilligInputs]) { + for input in inputs { + match input { + BrilligInputs::Single(expr) => { + self.fold_expr(expr); + } + BrilligInputs::Array(exprs) => { + for expr in exprs { + self.fold_expr(expr); + } + } + BrilligInputs::MemoryArray(_) => {} + } + } + } + + fn fold_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) { + for output in outputs { + match output { + BrilligOutputs::Simple(w) => { + self.fold(*w); + } + BrilligOutputs::Array(ws) => self.fold_many(ws.iter()), + } + } + } + + fn fold_blackbox(&mut self, call: &BlackBoxFuncCall) { + match call { + BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_function_inputs(iv.as_slice()); + self.fold_function_inputs(key.as_slice()); + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::AND { lhs, rhs, output } => { + self.fold_function_input(lhs); + self.fold_function_input(rhs); + self.fold(*output); + } + BlackBoxFuncCall::XOR { lhs, rhs, output } => { + self.fold_function_input(lhs); + self.fold_function_input(rhs); + self.fold(*output); + } + BlackBoxFuncCall::RANGE { input } => { + self.fold_function_input(input); + } + BlackBoxFuncCall::Blake2s { inputs, outputs } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::Blake3 { inputs, outputs } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::SchnorrVerify { + public_key_x, + public_key_y, + signature, + message, + output, + } => { + self.fold_function_input(public_key_x); + self.fold_function_input(public_key_y); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(message.as_slice()); + self.fold(*output); + } + BlackBoxFuncCall::EcdsaSecp256k1 { + public_key_x, + public_key_y, + signature, + hashed_message, + output, + } => { + self.fold_function_inputs(public_key_x.as_slice()); + self.fold_function_inputs(public_key_y.as_slice()); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(hashed_message.as_slice()); + self.fold(*output); + } + BlackBoxFuncCall::EcdsaSecp256r1 { + public_key_x, + public_key_y, + signature, + hashed_message, + output, + } => { + self.fold_function_inputs(public_key_x.as_slice()); + self.fold_function_inputs(public_key_y.as_slice()); + self.fold_function_inputs(signature.as_slice()); + self.fold_function_inputs(hashed_message.as_slice()); + self.fold(*output); + } + BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => { + self.fold_function_inputs(points.as_slice()); + self.fold_function_inputs(scalars.as_slice()); + let (x, y, i) = outputs; + self.fold(*x); + self.fold(*y); + self.fold(*i); + } + BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => { + self.fold_function_inputs(input1.as_slice()); + self.fold_function_inputs(input2.as_slice()); + let (x, y, i) = outputs; + self.fold(*x); + self.fold(*y); + self.fold(*i); + } + BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::RecursiveAggregation { + verification_key, + proof, + public_inputs, + key_hash, + proof_type: _, + } => { + self.fold_function_inputs(verification_key.as_slice()); + self.fold_function_inputs(proof.as_slice()); + self.fold_function_inputs(public_inputs.as_slice()); + self.fold_function_input(key_hash); + } + BlackBoxFuncCall::BigIntAdd { .. } + | BlackBoxFuncCall::BigIntSub { .. } + | BlackBoxFuncCall::BigIntMul { .. } + | BlackBoxFuncCall::BigIntDiv { .. } => {} + BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => { + self.fold_function_inputs(inputs.as_slice()); + } + BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => { + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_many(outputs.iter()); + } + BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => { + self.fold_function_inputs(inputs.as_slice()); + self.fold_function_inputs(hash_values.as_slice()); + self.fold_many(outputs.iter()); + } + } + } + + fn fold_function_input(&mut self, input: &FunctionInput) { + if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() { + self.fold(witness); + } + } + + fn fold_function_inputs(&mut self, inputs: &[FunctionInput]) { + for input in inputs { + self.fold_function_input(input); + } + } +} diff --git a/tooling/nargo_cli/src/cli/compile_cmd.rs b/tooling/nargo_cli/src/cli/compile_cmd.rs index 3317cd34e85..f134374f89e 100644 --- a/tooling/nargo_cli/src/cli/compile_cmd.rs +++ b/tooling/nargo_cli/src/cli/compile_cmd.rs @@ -418,6 +418,12 @@ mod tests { if verbose { // Compare where the most likely difference is. + similar_asserts::assert_eq!( + format!("{}", program_1.program), + format!("{}", program_2.program), + "optimization not idempotent for test program '{}'", + package.name + ); assert_eq!( program_1.program, program_2.program, "optimization not idempotent for test program '{}'",