From 4b43bd87637511d47af5332418b952006fcffd4c Mon Sep 17 00:00:00 2001 From: dbanks12 Date: Thu, 24 Oct 2024 18:22:37 +0000 Subject: [PATCH] chore: use Brillig opcode when possible for less-than operations on fields --- .../acir/src/circuit/black_box_functions.rs | 5 +++ .../src/curve_specific_solver.rs | 4 +++ .../bn254_blackbox_solver/src/lib.rs | 11 ++++++ .../brillig/brillig_gen/brillig_black_box.rs | 13 +++++++ .../noirc_evaluator/src/brillig/brillig_ir.rs | 8 +++++ .../src/brillig/brillig_ir/instructions.rs | 20 +++++++++++ .../ssa/acir_gen/acir_ir/generated_acir.rs | 5 +++ .../src/ssa/ir/instruction/call.rs | 1 + noir/noir-repo/noir_stdlib/src/field/bn254.nr | 34 ++++++++++++------- noir/noir-repo/tooling/lsp/src/solver.rs | 8 +++++ 10 files changed, 96 insertions(+), 13 deletions(-) diff --git a/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs b/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs index 2e5a94f1c50b..7c4b2e72743e 100644 --- a/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs +++ b/noir/noir-repo/acvm-repo/acir/src/circuit/black_box_functions.rs @@ -186,6 +186,9 @@ pub enum BlackBoxFunc { /// - state: [(witness, 32); 8] /// - output: [(witness, 32); 8] Sha256Compression, + + // Less than comparison between two field elements (only usable in unconstrained) + FieldLessThan, } impl std::fmt::Display for BlackBoxFunc { @@ -218,6 +221,7 @@ impl BlackBoxFunc { BlackBoxFunc::BigIntToLeBytes => "bigint_to_le_bytes", BlackBoxFunc::Poseidon2Permutation => "poseidon2_permutation", BlackBoxFunc::Sha256Compression => "sha256_compression", + BlackBoxFunc::FieldLessThan => "field_less_than", } } @@ -244,6 +248,7 @@ impl BlackBoxFunc { "bigint_to_le_bytes" => Some(BlackBoxFunc::BigIntToLeBytes), "poseidon2_permutation" => Some(BlackBoxFunc::Poseidon2Permutation), "sha256_compression" => Some(BlackBoxFunc::Sha256Compression), + "field_less_than" => Some(BlackBoxFunc::FieldLessThan), _ => None, } } diff --git a/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs b/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs index 869017f52eec..52bb7608f0ff 100644 --- a/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs +++ b/noir/noir-repo/acvm-repo/blackbox_solver/src/curve_specific_solver.rs @@ -34,6 +34,7 @@ pub trait BlackBoxFunctionSolver { _inputs: &[F], _len: u32, ) -> Result, BlackBoxResolutionError>; + fn field_less_than(&self, _input_x: &F, _input_y: &F) -> Result; } pub struct StubbedBlackBoxSolver; @@ -83,4 +84,7 @@ impl BlackBoxFunctionSolver for StubbedBlackBoxSolver { ) -> Result, BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::Poseidon2Permutation)) } + fn field_less_than(&self, _input_x: &F, _input_y: &F) -> Result { + Err(Self::fail(BlackBoxFunc::FieldLessThan)) + } } diff --git a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs index d74c17a52b50..5bf0bf25dfad 100644 --- a/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs +++ b/noir/noir-repo/acvm-repo/bn254_blackbox_solver/src/lib.rs @@ -74,4 +74,15 @@ impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { ) -> Result, BlackBoxResolutionError> { poseidon2_permutation(inputs, len) } + + fn field_less_than( + &self, + _input_x: &FieldElement, + _input_y: &FieldElement, + ) -> Result { + Err(BlackBoxResolutionError::Failed( + acir::BlackBoxFunc::FieldLessThan, + "field_less_than is not supported in acir, only in brillig/unconstrained".to_string(), + )) + } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs index 10c0e8b8e8c4..3297cfb2d53f 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs @@ -412,6 +412,19 @@ pub(crate) fn convert_black_box_call { + if let ( + [BrilligVariable::SingleAddr(input_x), BrilligVariable::SingleAddr(input_y)], + [BrilligVariable::SingleAddr(output)], + ) = (function_arguments, function_results) + { + brillig_context.field_less_than_instruction(*input_x, *input_y, *output); + } else { + unreachable!( + "ICE: FieldLessThan expects two register arguments one register result" + ) + } + } } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 4964ff27f608..c27749518cd3 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -201,6 +201,14 @@ pub(crate) mod tests { ) -> Result, BlackBoxResolutionError> { Ok(vec![0_u128.into(), 1_u128.into(), 2_u128.into(), 3_u128.into()]) } + + fn field_less_than( + &self, + _input_x: &FieldElement, + _input_y: &FieldElement, + ) -> Result { + Ok(true) + } } pub(crate) fn create_context(id: FunctionId) -> BrilligContext { diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs index 1ac672687f34..f972a4568967 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs @@ -38,6 +38,26 @@ impl BrilligContext< self.binary(lhs, rhs, result, operation); } + /// Procresses a field less than instruction. + /// + /// Just performs an assertion on bit size and then forwards to binary_instruction(). + pub(crate) fn field_less_than_instruction( + &mut self, + lhs: SingleAddrVariable, + rhs: SingleAddrVariable, + result: SingleAddrVariable, + ) { + let max_field_size = FieldElement::max_num_bits(); + assert!( + lhs.bit_size == max_field_size && rhs.bit_size == max_field_size, + "Expected bit sizes lhs and rhs to be {}, got {} and {} for 'field less than' operation", + lhs.bit_size, + rhs.bit_size, + max_field_size, + ); + self.binary_instruction(lhs, rhs, result, BrilligBinaryOp::LessThan); + } + /// Processes a not instruction. /// /// Not is computed using a subtraction operation as there is no native not instruction diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 01fcaef90425..51456004654c 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -343,6 +343,7 @@ impl GeneratedAcir { .expect("Compiler should generate correct size inputs"), outputs: outputs.try_into().expect("Compiler should generate correct size outputs"), }, + BlackBoxFunc::FieldLessThan => panic!("FieldLessThan is not supported in ACIR"), }; self.push_opcode(AcirOpcode::BlackBoxFuncCall(black_box_func_call)); @@ -666,6 +667,8 @@ fn black_box_func_expected_input_size(name: BlackBoxFunc) -> Option { // FromLeBytes takes a variable array of bytes as input BlackBoxFunc::BigIntFromLeBytes => None, + + BlackBoxFunc::FieldLessThan => panic!("FieldLessThan is not supported in ACIR"), } } @@ -714,6 +717,8 @@ fn black_box_expected_output_size(name: BlackBoxFunc) -> Option { // AES encryption returns a variable number of outputs BlackBoxFunc::AES128Encrypt => None, + + BlackBoxFunc::FieldLessThan => panic!("FieldLessThan is not supported in ACIR"), } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 3c931f8cada1..dd7c8fbcd51d 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -577,6 +577,7 @@ fn simplify_black_box_func( } BlackBoxFunc::Sha256Compression => SimplifyResult::None, //TODO(Guillaume) BlackBoxFunc::AES128Encrypt => SimplifyResult::None, + BlackBoxFunc::FieldLessThan => panic!("FieldLessThan is not supported in ACIR"), } } diff --git a/noir/noir-repo/noir_stdlib/src/field/bn254.nr b/noir/noir-repo/noir_stdlib/src/field/bn254.nr index 9349e67aed38..ea36ca8c686d 100644 --- a/noir/noir-repo/noir_stdlib/src/field/bn254.nr +++ b/noir/noir-repo/noir_stdlib/src/field/bn254.nr @@ -1,4 +1,5 @@ use crate::runtime::is_unconstrained; +use crate::field::bn254::field_less_than; // The low and high decomposition of the field modulus global PLO: Field = 53438638232309528389504892708671455233; @@ -25,23 +26,30 @@ pub(crate) unconstrained fn decompose_hint(x: Field) -> (Field, Field) { compute_decomposition(x) } +#[foreign(field_less_than)] +pub fn field_less_than(_x: Field, _y: Field) -> bool {} + fn compute_lt(x: Field, y: Field, num_bytes: u32) -> bool { - let x_bytes: [u8; 32] = x.to_le_bytes(); - let y_bytes: [u8; 32] = y.to_le_bytes(); - let mut x_is_lt = false; - let mut done = false; - for i in 0..num_bytes { - if (!done) { - let x_byte = x_bytes[num_bytes - 1 - i]; - let y_byte = y_bytes[num_bytes - 1 - i]; - let bytes_match = x_byte == y_byte; - if !bytes_match { - x_is_lt = x_byte < y_byte; - done = true; + if is_unconstrained() { + field_less_than(x, y) + } else { + let x_bytes: [u8; 32] = x.to_le_bytes(); + let y_bytes: [u8; 32] = y.to_le_bytes(); + let mut x_is_lt = false; + let mut done = false; + for i in 0..num_bytes { + if (!done) { + let x_byte = x_bytes[num_bytes - 1 - i]; + let y_byte = y_bytes[num_bytes - 1 - i]; + let bytes_match = x_byte == y_byte; + if !bytes_match { + x_is_lt = x_byte < y_byte; + done = true; + } } } + x_is_lt } - x_is_lt } fn compute_lte(x: Field, y: Field, num_bytes: u32) -> bool { diff --git a/noir/noir-repo/tooling/lsp/src/solver.rs b/noir/noir-repo/tooling/lsp/src/solver.rs index 3c2d7499880f..060fb0c20a4a 100644 --- a/noir/noir-repo/tooling/lsp/src/solver.rs +++ b/noir/noir-repo/tooling/lsp/src/solver.rs @@ -50,4 +50,12 @@ impl BlackBoxFunctionSolver for WrapperSolver { ) -> Result, acvm::BlackBoxResolutionError> { self.0.poseidon2_permutation(inputs, len) } + + fn field_less_than( + &self, + _input_x: &acvm::FieldElement, + _input_y: &acvm::FieldElement, + ) -> Result { + self.0.field_less_than(_input_x, _input_y) + } }