diff --git a/Cargo.toml b/Cargo.toml index ff3b8b690..0f4030a9e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,6 @@ brillig = { version = "0.20.0", path = "brillig", default-features = false } blackbox_solver = { package = "acvm_blackbox_solver", version = "0.20.0", path = "blackbox_solver", default-features = false } bincode = "1.3.3" -rmp-serde = "1.1.0" num-bigint = "0.4" num-traits = "0.2" diff --git a/acir/Cargo.toml b/acir/Cargo.toml index c10b0b610..3f2a25ac5 100644 --- a/acir/Cargo.toml +++ b/acir/Cargo.toml @@ -15,7 +15,7 @@ acir_field.workspace = true brillig.workspace = true serde.workspace = true thiserror.workspace = true -rmp-serde = { version ="1.1.0", optional = true } +rmp-serde = { version = "1.1.0", optional = true } flate2 = "1.0.24" bincode.workspace = true diff --git a/acir/src/circuit/brillig.rs b/acir/src/circuit/brillig.rs index 8bf59dd1a..5b1ec9d03 100644 --- a/acir/src/circuit/brillig.rs +++ b/acir/src/circuit/brillig.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; /// Inputs for the Brillig VM. These are the initial inputs /// that the Brillig VM will use to start. -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] pub enum BrilligInputs { Single(Expression), Array(Vec), @@ -13,7 +13,7 @@ pub enum BrilligInputs { /// Outputs for the Brillig VM. Once the VM has completed /// execution, this will be the object that is returned. -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug, Hash)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] pub enum BrilligOutputs { Simple(Witness), Array(Vec), diff --git a/acir/src/circuit/mod.rs b/acir/src/circuit/mod.rs index a9cd30b8f..2e75a3fe5 100644 --- a/acir/src/circuit/mod.rs +++ b/acir/src/circuit/mod.rs @@ -20,6 +20,8 @@ pub struct Circuit { pub current_witness_index: u32, pub opcodes: Vec, + /// The set of private inputs to the circuit. + pub private_parameters: BTreeSet, // ACIR distinguishes between the public inputs which are provided externally or calculated within the circuit and returned. // The elements of these sets may not be mutually exclusive, i.e. a parameter may be returned from the circuit. // All public inputs (parameters and return values) must be provided to the verifier at verification time. @@ -43,6 +45,11 @@ impl Circuit { self.current_witness_index + 1 } + /// Returns all witnesses which are required to execute the circuit successfully. + pub fn circuit_arguments(&self) -> BTreeSet { + self.private_parameters.union(&self.public_parameters.0).cloned().collect() + } + /// Returns all public inputs. This includes those provided as parameters to the circuit and those /// computed as return values. pub fn public_inputs(&self) -> PublicInputs { @@ -178,6 +185,7 @@ mod tests { let circuit = Circuit { current_witness_index: 5, opcodes: vec![and_opcode(), range_opcode(), directive_opcode()], + private_parameters: BTreeSet::new(), public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])), }; @@ -206,6 +214,7 @@ mod tests { range_opcode(), and_opcode(), ], + private_parameters: BTreeSet::new(), public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(2)])), }; diff --git a/acvm/src/compiler/mod.rs b/acvm/src/compiler/mod.rs index bac25286d..b8111479a 100644 --- a/acvm/src/compiler/mod.rs +++ b/acvm/src/compiler/mod.rs @@ -1,5 +1,8 @@ use acir::{ - circuit::{opcodes::UnsupportedMemoryOpcode, Circuit, Opcode, OpcodeLabel}, + circuit::{ + brillig::BrilligOutputs, directives::Directive, opcodes::UnsupportedMemoryOpcode, Circuit, + Opcode, OpcodeLabel, + }, native_types::{Expression, Witness}, BlackBoxFunc, FieldElement, }; @@ -57,12 +60,18 @@ pub fn compile( let range_optimizer = RangeOptimizer::new(acir); let (acir, opcode_label) = range_optimizer.replace_redundant_ranges(opcode_label); - let transformer = match &np_language { + let mut transformer = match &np_language { crate::Language::R1CS => { let transformer = R1CSTransformer::new(acir); return Ok((transformer.transform(), opcode_label)); } - crate::Language::PLONKCSat { width } => CSatTransformer::new(*width), + crate::Language::PLONKCSat { width } => { + let mut csat = CSatTransformer::new(*width); + for value in acir.circuit_arguments() { + csat.mark_solvable(value); + } + csat + } }; // TODO: the code below is only for CSAT transformer @@ -106,20 +115,114 @@ pub fn compile( transformed_gates.push(Opcode::Arithmetic(gate)); } } - other_gate => { + Opcode::BlackBoxFuncCall(func) => { + match func { + acir::circuit::opcodes::BlackBoxFuncCall::AND { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::XOR { output, .. } => { + transformer.mark_solvable(*output) + } + acir::circuit::opcodes::BlackBoxFuncCall::RANGE { .. } => (), + acir::circuit::opcodes::BlackBoxFuncCall::SHA256 { outputs, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256 { outputs, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::Keccak256VariableLength { + outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::RecursiveAggregation { + output_aggregation_object: outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::Blake2s { outputs, .. } => { + for witness in outputs { + transformer.mark_solvable(*witness); + } + } + acir::circuit::opcodes::BlackBoxFuncCall::FixedBaseScalarMul { + outputs, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::Pedersen { outputs, .. } => { + transformer.mark_solvable(outputs.0); + transformer.mark_solvable(outputs.1) + } + acir::circuit::opcodes::BlackBoxFuncCall::HashToField128Security { + output, + .. + } + | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256k1 { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256r1 { output, .. } + | acir::circuit::opcodes::BlackBoxFuncCall::SchnorrVerify { output, .. } => { + transformer.mark_solvable(*output) + } + } + new_opcode_labels.push(opcode_label[index]); - transformed_gates.push(other_gate.clone()) + transformed_gates.push(opcode.clone()); + } + Opcode::Directive(directive) => { + match directive { + Directive::Invert { result, .. } => { + transformer.mark_solvable(*result); + } + Directive::Quotient(quotient_directive) => { + transformer.mark_solvable(quotient_directive.q); + transformer.mark_solvable(quotient_directive.r); + } + Directive::ToLeRadix { b, .. } => { + for w in b { + transformer.mark_solvable(*w); + } + } + Directive::PermutationSort { bits, .. } => { + for w in bits { + transformer.mark_solvable(*w); + } + } + Directive::Log(_) => (), + } + new_opcode_labels.push(opcode_label[index]); + transformed_gates.push(opcode.clone()); + } + Opcode::MemoryInit { .. } => { + // `MemoryInit` does not write values to the `WitnessMap` + } + Opcode::MemoryOp { op, .. } => { + for (_, w1, w2) in &op.index.mul_terms { + transformer.mark_solvable(*w1); + transformer.mark_solvable(*w2); + } + for (_, w) in &op.index.linear_combinations { + transformer.mark_solvable(*w); + } + } + + Opcode::Block(_) | Opcode::ROM(_) | Opcode::RAM(_) => { + unimplemented!("Stepwise execution is not compatible with {}", opcode.name()) + } + Opcode::Brillig(brillig) => { + for output in &brillig.outputs { + match output { + BrilligOutputs::Simple(w) => transformer.mark_solvable(*w), + BrilligOutputs::Array(v) => { + for w in v { + transformer.mark_solvable(*w); + } + } + } + } + new_opcode_labels.push(opcode_label[index]); + transformed_gates.push(opcode.clone()); } } } let current_witness_index = next_witness_index - 1; - Ok(( Circuit { current_witness_index, opcodes: transformed_gates, // The optimizer does not add new public inputs + private_parameters: acir.private_parameters, public_parameters: acir.public_parameters, return_values: acir.return_values, }, diff --git a/acvm/src/compiler/optimizers/redundant_range.rs b/acvm/src/compiler/optimizers/redundant_range.rs index 62f149990..c4feb8dac 100644 --- a/acvm/src/compiler/optimizers/redundant_range.rs +++ b/acvm/src/compiler/optimizers/redundant_range.rs @@ -112,6 +112,7 @@ impl RangeOptimizer { Circuit { current_witness_index: self.circuit.current_witness_index, opcodes: optimized_opcodes, + private_parameters: self.circuit.private_parameters, public_parameters: self.circuit.public_parameters, return_values: self.circuit.return_values, }, @@ -139,6 +140,8 @@ fn extract_range_opcode(opcode: &Opcode) -> Option<(Witness, u32)> { #[cfg(test)] mod tests { + use std::collections::BTreeSet; + use crate::compiler::optimizers::redundant_range::{extract_range_opcode, RangeOptimizer}; use acir::{ circuit::{ @@ -163,6 +166,7 @@ mod tests { Circuit { current_witness_index: 1, opcodes, + private_parameters: BTreeSet::new(), public_parameters: PublicInputs::default(), return_values: PublicInputs::default(), } diff --git a/acvm/src/compiler/transformers/csat.rs b/acvm/src/compiler/transformers/csat.rs index 3fd4d037c..b520056cd 100644 --- a/acvm/src/compiler/transformers/csat.rs +++ b/acvm/src/compiler/transformers/csat.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering; +use std::{cmp::Ordering, collections::HashSet}; use acir::{ native_types::{Expression, Witness}, @@ -17,6 +17,8 @@ use indexmap::IndexMap; // Have a single transformer that you instantiate with a width, then pass many gates through pub(crate) struct CSatTransformer { width: usize, + /// Track the witness that can be solved + solvable_witness: HashSet, } impl CSatTransformer { @@ -24,14 +26,46 @@ impl CSatTransformer { pub(crate) fn new(width: usize) -> CSatTransformer { assert!(width > 2); - CSatTransformer { width } + CSatTransformer { width, solvable_witness: HashSet::new() } + } + + /// Check if the equation 'expression=0' can be solved, and if yes, add the solved witness to set of solvable witness + fn try_solve(&mut self, gate: &Expression) { + let mut unresolved = Vec::new(); + for (_, w1, w2) in &gate.mul_terms { + if !self.solvable_witness.contains(w1) { + unresolved.push(w1); + if !self.solvable_witness.contains(w2) { + return; + } + } + if !self.solvable_witness.contains(w2) { + unresolved.push(w2); + if !self.solvable_witness.contains(w1) { + return; + } + } + } + for (_, w) in &gate.linear_combinations { + if !self.solvable_witness.contains(w) { + unresolved.push(w); + } + } + if unresolved.len() == 1 { + self.mark_solvable(*unresolved[0]); + } + } + + /// Adds the witness to set of solvable witness + pub(crate) fn mark_solvable(&mut self, witness: Witness) { + self.solvable_witness.insert(witness); } // Still missing dead witness optimization. // To do this, we will need the whole set of arithmetic gates // I think it can also be done before the local optimization seen here, as dead variables will come from the user pub(crate) fn transform( - &self, + &mut self, gate: Expression, intermediate_variables: &mut IndexMap, num_witness: &mut u32, @@ -45,6 +79,7 @@ impl CSatTransformer { let mut gate = self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness); gate.sort(); + self.try_solve(&gate); gate } @@ -72,7 +107,7 @@ impl CSatTransformer { // We can no longer extract another full gate, hence the algorithm terminates. Creating two intermediate variables t and t2. // This stage of preprocessing does not guarantee that all polynomials can fit into a gate. It only guarantees that all full gates have been extracted from each polynomial fn full_gate_scan_optimization( - &self, + &mut self, mut gate: Expression, intermediate_variables: &mut IndexMap, num_witness: &mut u32, @@ -98,9 +133,15 @@ impl CSatTransformer { // This will be our new gate which will be equal to `self` except we will have intermediate variables that will be constrained to any // subset of the terms that can be represented as full gates let mut new_gate = Expression::default(); - - while !gate.mul_terms.is_empty() { - let pair = gate.mul_terms[0]; + let mut remaining_mul_terms = Vec::with_capacity(gate.mul_terms.len()); + for pair in gate.mul_terms { + // We want to layout solvable intermediate variable, if we cannot solve one of the witness + // that means the intermediate gate will not be immediately solvable + if !self.solvable_witness.contains(&pair.1) || !self.solvable_witness.contains(&pair.2) + { + remaining_mul_terms.push(pair); + continue; + } // Check if this pair is present in the simplified fan-in // We are assuming that the fan-in/fan-out has been simplified. @@ -153,17 +194,23 @@ impl CSatTransformer { } // Now we have used up 2 spaces in our arithmetic gate. The width now dictates, how many more we can add - let remaining_space = self.width - 2 - 1; // We minus 1 because we need an extra space to contain the intermediate variable - // Keep adding terms until we have no more left, or we reach the width - for _ in 0..remaining_space { + let mut remaining_space = self.width - 2 - 1; // We minus 1 because we need an extra space to contain the intermediate variable + // Keep adding terms until we have no more left, or we reach the width + let mut remaining_linear_terms = + Vec::with_capacity(gate.linear_combinations.len()); + while remaining_space > 0 { if let Some(wire_term) = gate.linear_combinations.pop() { // Add this element into the new gate - intermediate_gate.linear_combinations.push(wire_term); + if self.solvable_witness.contains(&wire_term.1) { + intermediate_gate.linear_combinations.push(wire_term); + remaining_space -= 1; + } else { + remaining_linear_terms.push(wire_term); + } } else { - // No more elements left in the old gate, we could stop the whole function - // We could alternative let it keep going, as it will never reach this branch again since there are no more elements left - // XXX: Future optimization - // no_more_left = true + // No more usable elements left in the old gate + gate.linear_combinations = remaining_linear_terms; + break; } } // Constraint this intermediate_gate to be equal to the temp variable by adding it into the IndexMap @@ -179,12 +226,12 @@ impl CSatTransformer { ); // Add intermediate variable to the new gate instead of the full gate + self.mark_solvable(inter_var.1); new_gate.linear_combinations.push(inter_var); } }; - // Remove this term as we are finished processing it - gate.mul_terms.remove(0); } + gate.mul_terms = remaining_mul_terms; // Add the rest of the elements back into the new_gate new_gate.mul_terms.extend(gate.mul_terms.clone()); @@ -268,7 +315,7 @@ impl CSatTransformer { // // Cases, a lot of mul terms, a lot of fan-in terms, 50/50 fn partial_gate_scan_optimization( - &self, + &mut self, mut gate: Expression, intermediate_variables: &mut IndexMap, num_witness: &mut u32, @@ -283,24 +330,32 @@ impl CSatTransformer { } // 2. Create Intermediate variables for the multiplication gates + let mut remaining_mul_terms = Vec::with_capacity(gate.mul_terms.len()); for mul_term in gate.mul_terms.clone().into_iter() { - let mut intermediate_gate = Expression::default(); - - // Push mul term into the gate - intermediate_gate.mul_terms.push(mul_term); - // Get an intermediate variable which squashes the multiplication term - let inter_var = Self::get_or_create_intermediate_vars( - intermediate_variables, - intermediate_gate, - num_witness, - ); - - // Add intermediate variable as a part of the fan-in for the original gate - gate.linear_combinations.push(inter_var); + if self.solvable_witness.contains(&mul_term.1) + && self.solvable_witness.contains(&mul_term.2) + { + let mut intermediate_gate = Expression::default(); + + // Push mul term into the gate + intermediate_gate.mul_terms.push(mul_term); + // Get an intermediate variable which squashes the multiplication term + let inter_var = Self::get_or_create_intermediate_vars( + intermediate_variables, + intermediate_gate, + num_witness, + ); + + // Add intermediate variable as a part of the fan-in for the original gate + gate.linear_combinations.push(inter_var); + self.mark_solvable(inter_var.1); + } else { + remaining_mul_terms.push(mul_term); + } } // Remove all of the mul terms as we have intermediate variables to represent them now - gate.mul_terms.clear(); + gate.mul_terms = remaining_mul_terms; // We now only have a polynomial with only fan-in/fan-out terms i.e. terms of the form Ax + By + Cd + ... // Lets create intermediate variables if all of them cannot fit into the width @@ -318,29 +373,37 @@ impl CSatTransformer { // Collect as many terms up to the given width-1 and constrain them to an intermediate variable let mut intermediate_gate = Expression::default(); - for _ in 0..(self.width - 1) { - match gate.linear_combinations.pop() { - Some(term) => { - intermediate_gate.linear_combinations.push(term); - } - None => { - break; // We can also do nothing here - } - }; - } - let inter_var = Self::get_or_create_intermediate_vars( - intermediate_variables, - intermediate_gate, - num_witness, - ); + let mut remaining_linear_terms = Vec::with_capacity(gate.linear_combinations.len()); - added.push(inter_var); + for term in gate.linear_combinations { + if self.solvable_witness.contains(&term.1) + && intermediate_gate.linear_combinations.len() < self.width - 1 + { + intermediate_gate.linear_combinations.push(term); + } else { + remaining_linear_terms.push(term); + } + } + gate.linear_combinations = remaining_linear_terms; + let not_full = intermediate_gate.linear_combinations.len() < self.width - 1; + if intermediate_gate.linear_combinations.len() > 1 { + let inter_var = Self::get_or_create_intermediate_vars( + intermediate_variables, + intermediate_gate, + num_witness, + ); + self.mark_solvable(inter_var.1); + added.push(inter_var); + } + // The intermediate gate is not full, but the gate still has too many terms + if not_full && gate.linear_combinations.len() > self.width { + unreachable!("Could not reduce the expression"); + } } // Add back the intermediate variables to // keep consistency with the original equation. gate.linear_combinations.extend(added); - self.partial_gate_scan_optimization(gate, intermediate_variables, num_witness) } } @@ -368,14 +431,17 @@ fn simple_reduction_smoke_test() { let mut num_witness = 4; - let optimizer = CSatTransformer::new(3); + let mut optimizer = CSatTransformer::new(3); + optimizer.mark_solvable(b); + optimizer.mark_solvable(c); + optimizer.mark_solvable(d); let got_optimized_gate_a = optimizer.transform(gate_a, &mut intermediate_variables, &mut num_witness); // a = b + c + d => a - b - c - d = 0 // For width3, the result becomes: - // a - b + e = 0 - // - c - d - e = 0 + // a - d + e = 0 + // - c - b - e = 0 // // a - b + e = 0 let e = Witness(4); @@ -383,7 +449,7 @@ fn simple_reduction_smoke_test() { mul_terms: vec![], linear_combinations: vec![ (FieldElement::one(), a), - (-FieldElement::one(), b), + (-FieldElement::one(), d), (FieldElement::one(), e), ], q_c: FieldElement::zero(), @@ -392,13 +458,52 @@ fn simple_reduction_smoke_test() { assert_eq!(intermediate_variables.len(), 1); - // e = - c - d + // e = - c - b let expected_intermediate_gate = Expression { mul_terms: vec![], - linear_combinations: vec![(-FieldElement::one(), d), (-FieldElement::one(), c)], + linear_combinations: vec![(-FieldElement::one(), c), (-FieldElement::one(), b)], q_c: FieldElement::zero(), }; let (_, normalized_gate) = CSatTransformer::normalize(expected_intermediate_gate); assert!(intermediate_variables.contains_key(&normalized_gate)); assert_eq!(intermediate_variables[&normalized_gate].1, e); } + +#[test] +fn stepwise_reduction_test() { + let a = Witness(0); + let b = Witness(1); + let c = Witness(2); + let d = Witness(3); + let e = Witness(4); + + // a = b + c + d + e; + let gate_a = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (-FieldElement::one(), a), + (FieldElement::one(), b), + (FieldElement::one(), c), + (FieldElement::one(), d), + (FieldElement::one(), e), + ], + q_c: FieldElement::zero(), + }; + + let mut intermediate_variables: IndexMap = IndexMap::new(); + + let mut num_witness = 4; + + let mut optimizer = CSatTransformer::new(3); + optimizer.mark_solvable(a); + optimizer.mark_solvable(c); + optimizer.mark_solvable(d); + optimizer.mark_solvable(e); + let got_optimized_gate_a = + optimizer.transform(gate_a, &mut intermediate_variables, &mut num_witness); + + let witnesses: Vec = + got_optimized_gate_a.linear_combinations.iter().map(|(_, w)| *w).collect(); + // Since b is not known, it cannot be put inside intermediate gates, so it must belong to the transformed gate. + assert!(witnesses.contains(&b)); +} diff --git a/acvm/src/compiler/transformers/fallback.rs b/acvm/src/compiler/transformers/fallback.rs index 79805eb40..747f36ecd 100644 --- a/acvm/src/compiler/transformers/fallback.rs +++ b/acvm/src/compiler/transformers/fallback.rs @@ -67,6 +67,7 @@ impl FallbackTransformer { Circuit { current_witness_index: witness_idx, opcodes: acir_supported_opcodes, + private_parameters: acir.private_parameters, public_parameters: acir.public_parameters, return_values: acir.return_values, },