diff --git a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index fbe179d7c04..8bb9a680ea9 100644 --- a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -42,6 +42,10 @@ impl FunctionInput { pub fn witness(witness: Witness, num_bits: u32) -> FunctionInput { FunctionInput { input: ConstantOrWitnessEnum::Witness(witness), num_bits } } + + pub fn is_constant(&self) -> bool { + matches!(self.input, ConstantOrWitnessEnum::Constant(_)) + } } #[derive(Clone, PartialEq, Eq, Debug, Error)] diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index b0f283eeaeb..1069416b7b8 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -8,7 +8,9 @@ use crate::ssa::ir::dfg::CallStack; use crate::ssa::ir::types::Type as SsaType; use crate::ssa::ir::{instruction::Endian, types::NumericType}; use acvm::acir::circuit::brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}; -use acvm::acir::circuit::opcodes::{AcirFunctionId, BlockId, BlockType, MemOp}; +use acvm::acir::circuit::opcodes::{ + AcirFunctionId, BlockId, BlockType, ConstantOrWitnessEnum, MemOp, +}; use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, ExpressionWidth, Opcode}; use acvm::brillig_vm::{MemoryValue, VMStatus, VM}; use acvm::{ @@ -1459,22 +1461,7 @@ impl AcirContext { } _ => (vec![], vec![]), }; - // Allow constant inputs for most blackbox - // EmbeddedCurveAdd needs to be fixed first in bb - // Poseidon2Permutation requires witness input - let allow_constant_inputs = matches!( - name, - BlackBoxFunc::MultiScalarMul - | BlackBoxFunc::Keccakf1600 - | BlackBoxFunc::Blake2s - | BlackBoxFunc::Blake3 - | BlackBoxFunc::AND - | BlackBoxFunc::XOR - | BlackBoxFunc::AES128Encrypt - | BlackBoxFunc::EmbeddedCurveAdd - ); - // Convert `AcirVar` to `FunctionInput` - let inputs = self.prepare_inputs_for_black_box_func_call(inputs, allow_constant_inputs)?; + let inputs = self.prepare_inputs_for_black_box_func(inputs, name)?; // Call Black box with `FunctionInput` let mut results = vecmap(&constant_outputs, |c| self.add_constant(*c)); let outputs = self.acir_ir.call_black_box( @@ -1496,6 +1483,34 @@ impl AcirContext { Ok(results) } + fn prepare_inputs_for_black_box_func( + &mut self, + inputs: Vec, + name: BlackBoxFunc, + ) -> Result>>, RuntimeError> { + // Allow constant inputs for most blackbox, but: + // - EmbeddedCurveAdd requires all-or-nothing constant inputs + // - Poseidon2Permutation requires witness input + let allow_constant_inputs = matches!( + name, + BlackBoxFunc::MultiScalarMul + | BlackBoxFunc::Keccakf1600 + | BlackBoxFunc::Blake2s + | BlackBoxFunc::Blake3 + | BlackBoxFunc::AND + | BlackBoxFunc::XOR + | BlackBoxFunc::AES128Encrypt + | BlackBoxFunc::EmbeddedCurveAdd + ); + // Convert `AcirVar` to `FunctionInput` + let mut inputs = + self.prepare_inputs_for_black_box_func_call(inputs, allow_constant_inputs)?; + if name == BlackBoxFunc::EmbeddedCurveAdd { + inputs = self.all_or_nothing_for_ec_add(inputs)?; + } + Ok(inputs) + } + /// Black box function calls expect their inputs to be in a specific data structure (FunctionInput). /// /// This function will convert `AcirVar` into `FunctionInput` for a blackbox function call. @@ -1536,6 +1551,41 @@ impl AcirContext { Ok(witnesses) } + /// EcAdd has 6 inputs representing the two points to add + /// Each point must be either all constant, or all witnesses + fn all_or_nothing_for_ec_add( + &mut self, + inputs: Vec>>, + ) -> Result>>, RuntimeError> { + let mut has_constant = false; + let mut has_witness = false; + let mut result = inputs.clone(); + for (i, input) in inputs.iter().enumerate() { + if input[0].is_constant() { + has_constant = true; + } else { + has_witness = true; + } + if i % 3 == 2 { + if has_constant && has_witness { + // Convert the constants to witness if mixed constant and witness, + for j in i - 2..i + 1 { + if let ConstantOrWitnessEnum::Constant(constant) = inputs[j][0].input() { + let constant = self.add_constant(constant); + let witness_var = self.get_or_create_witness_var(constant)?; + let witness = self.var_to_witness(witness_var)?; + result[j] = + vec![FunctionInput::witness(witness, inputs[j][0].num_bits())]; + } + } + } + has_constant = false; + has_witness = false; + } + } + Ok(result) + } + /// Returns a vector of `AcirVar`s constrained to be the decomposition of the given input /// over given radix. ///