Skip to content

Commit

Permalink
feat: add assertions for ACVM FunctionInput bit_size (#5864)
Browse files Browse the repository at this point in the history
# Description

- [x] Convert `FunctionInput::constant/witness` assertions to solver
errors

## Problem\*

- Resolves #5793

## Summary\*

- Adds assertions 

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <tom@tomfren.ch>
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 11, 2024
1 parent d4832ec commit 8712f4c
Show file tree
Hide file tree
Showing 25 changed files with 342 additions and 174 deletions.
6 changes: 4 additions & 2 deletions acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ use serde::{Deserialize, Serialize};
mod black_box_function_call;
mod memory_operation;

pub use black_box_function_call::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput};
pub use black_box_function_call::{
BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput, InvalidInputBitSize,
};
pub use memory_operation::{BlockId, MemOp};

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
Expand All @@ -40,7 +42,7 @@ pub enum Opcode<F> {
/// values which define the opcode.
///
/// A general expression of assert-zero opcode is the following:
/// ```
/// ```text
/// \sum_{i,j} {q_M}_{i,j}w_iw_j + \sum_i q_iw_i +q_c = 0
/// ```
///
Expand Down
36 changes: 31 additions & 5 deletions acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use crate::native_types::Witness;
use crate::BlackBoxFunc;
use crate::{AcirField, BlackBoxFunc};

use serde::{Deserialize, Deserializer, Serialize, Serializer};
use thiserror::Error;

// Note: Some functions will not use all of the witness
// So we need to supply how many bits of the witness is needed
Expand All @@ -13,8 +15,8 @@ pub enum ConstantOrWitnessEnum<F> {

#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionInput<F> {
pub input: ConstantOrWitnessEnum<F>,
pub num_bits: u32,
input: ConstantOrWitnessEnum<F>,
num_bits: u32,
}

impl<F> FunctionInput<F> {
Expand All @@ -25,16 +27,40 @@ impl<F> FunctionInput<F> {
}
}

pub fn input(self) -> ConstantOrWitnessEnum<F> {
self.input
}

pub fn input_ref(&self) -> &ConstantOrWitnessEnum<F> {
&self.input
}

pub fn num_bits(&self) -> u32 {
self.num_bits
}

pub fn witness(witness: Witness, num_bits: u32) -> FunctionInput<F> {
FunctionInput { input: ConstantOrWitnessEnum::Witness(witness), num_bits }
}
}

#[derive(Clone, PartialEq, Eq, Debug, Error)]
#[error("FunctionInput value has too many bits: value: {value}, {value_num_bits} >= {max_bits}")]
pub struct InvalidInputBitSize {
pub value: String,
pub value_num_bits: u32,
pub max_bits: u32,
}

pub fn constant(value: F, num_bits: u32) -> FunctionInput<F> {
FunctionInput { input: ConstantOrWitnessEnum::Constant(value), num_bits }
impl<F: AcirField> FunctionInput<F> {
pub fn constant(value: F, max_bits: u32) -> Result<FunctionInput<F>, InvalidInputBitSize> {
if value.num_bits() <= max_bits {
Ok(FunctionInput { input: ConstantOrWitnessEnum::Constant(value), num_bits: max_bits })
} else {
let value_num_bits = value.num_bits();
let value = format!("{}", value);
Err(InvalidInputBitSize { value, value_num_bits, max_bits })
}
}
}

Expand Down
1 change: 1 addition & 0 deletions acvm-repo/acir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use acir_field;
pub use acir_field::{AcirField, FieldElement};
pub use brillig;
pub use circuit::black_box_functions::BlackBoxFunc;
pub use circuit::opcodes::InvalidInputBitSize;

#[cfg(test)]
mod reflection {
Expand Down
16 changes: 8 additions & 8 deletions acvm-repo/acir/tests/test_program_serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ fn multi_scalar_mul_circuit() {
let multi_scalar_mul: Opcode<FieldElement> =
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::MultiScalarMul {
points: vec![
FunctionInput::witness(Witness(1), 128),
FunctionInput::witness(Witness(2), 128),
FunctionInput::witness(Witness(1), FieldElement::max_num_bits()),
FunctionInput::witness(Witness(2), FieldElement::max_num_bits()),
FunctionInput::witness(Witness(3), 1),
],
scalars: vec![
FunctionInput::witness(Witness(4), 128),
FunctionInput::witness(Witness(5), 128),
FunctionInput::witness(Witness(4), FieldElement::max_num_bits()),
FunctionInput::witness(Witness(5), FieldElement::max_num_bits()),
],
outputs: (Witness(6), Witness(7), Witness(8)),
});
Expand All @@ -91,10 +91,10 @@ fn multi_scalar_mul_circuit() {
let bytes = Program::serialize_program(&program);

let expected_serialization: Vec<u8> = vec![
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 11, 10, 0, 32, 8, 67, 43, 181, 15, 116, 232,
142, 158, 210, 130, 149, 240, 112, 234, 212, 156, 78, 12, 39, 67, 71, 158, 142, 80, 29, 44,
228, 66, 90, 168, 119, 189, 74, 115, 131, 174, 78, 115, 58, 124, 70, 254, 130, 59, 74, 253,
68, 255, 255, 221, 39, 54, 221, 93, 91, 132, 193, 0, 0, 0,
31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 93, 141, 11, 10, 0, 32, 8, 67, 43, 181, 15, 116, 255,
227, 70, 74, 11, 86, 194, 195, 169, 83, 115, 58, 49, 156, 12, 29, 121, 58, 66, 117, 176,
144, 11, 105, 161, 222, 245, 42, 205, 13, 186, 58, 205, 233, 240, 25, 249, 11, 238, 40,
245, 19, 253, 255, 119, 159, 216, 103, 157, 249, 169, 193, 0, 0, 0,
];

assert_eq!(bytes, expected_serialization)
Expand Down
42 changes: 28 additions & 14 deletions acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use acir::{
circuit::{
opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum, FunctionInput},
opcodes::{BlackBoxFuncCall, ConstantOrWitnessEnum},
Circuit, Opcode,
},
native_types::Witness,
Expand Down Expand Up @@ -73,10 +73,13 @@ impl<F: AcirField> RangeOptimizer<F> {
}
}

Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input:
FunctionInput { input: ConstantOrWitnessEnum::Witness(witness), num_bits },
}) => Some((*witness, *num_bits)),
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
if let ConstantOrWitnessEnum::Witness(witness) = input.input() {
Some((witness, input.num_bits()))
} else {
None
}
}

_ => None,
}) else {
Expand Down Expand Up @@ -106,17 +109,28 @@ impl<F: AcirField> RangeOptimizer<F> {
let mut new_order_list = Vec::with_capacity(order_list.len());
let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len());
for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() {
let (witness, num_bits) = match opcode {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input:
FunctionInput { input: ConstantOrWitnessEnum::Witness(w), num_bits: bits },
}) => (w, bits),
_ => {
// If its not the range opcode, add it to the opcode
// list and continue;
let (witness, num_bits) = {
// If its not the range opcode, add it to the opcode
// list and continue;
let mut push_non_range_opcode = || {
optimized_opcodes.push(opcode.clone());
new_order_list.push(order_list[idx]);
continue;
};

match opcode {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
match input.input() {
ConstantOrWitnessEnum::Witness(witness) => (witness, input.num_bits()),
_ => {
push_non_range_opcode();
continue;
}
}
}
_ => {
push_non_range_opcode();
continue;
}
}
};
// If we've already applied the range constraint for this witness then skip this opcode.
Expand Down
2 changes: 1 addition & 1 deletion acvm-repo/acvm/src/pwg/blackbox/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl AcvmBigIntSolver {
) -> Result<(), OpcodeResolutionError<F>> {
let bytes = inputs
.iter()
.map(|input| input_to_value(initial_witness, *input).unwrap().to_u128() as u8)
.map(|input| input_to_value(initial_witness, *input, false).unwrap().to_u128() as u8)
.collect::<Vec<u8>>();
self.bigint_solver.bigint_from_bytes(&bytes, modulus, output)?;
Ok(())
Expand Down
16 changes: 8 additions & 8 deletions acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ pub(super) fn multi_scalar_mul<F: AcirField>(
outputs: (Witness, Witness, Witness),
) -> Result<(), OpcodeResolutionError<F>> {
let points: Result<Vec<_>, _> =
points.iter().map(|input| input_to_value(initial_witness, *input)).collect();
points.iter().map(|input| input_to_value(initial_witness, *input, false)).collect();
let points: Vec<_> = points?.into_iter().collect();

let scalars: Result<Vec<_>, _> =
scalars.iter().map(|input| input_to_value(initial_witness, *input)).collect();
scalars.iter().map(|input| input_to_value(initial_witness, *input, false)).collect();
let mut scalars_lo = Vec::new();
let mut scalars_hi = Vec::new();
for (i, scalar) in scalars?.into_iter().enumerate() {
Expand Down Expand Up @@ -47,12 +47,12 @@ pub(super) fn embedded_curve_add<F: AcirField>(
input2: [FunctionInput<F>; 3],
outputs: (Witness, Witness, Witness),
) -> Result<(), OpcodeResolutionError<F>> {
let input1_x = input_to_value(initial_witness, input1[0])?;
let input1_y = input_to_value(initial_witness, input1[1])?;
let input1_infinite = input_to_value(initial_witness, input1[2])?;
let input2_x = input_to_value(initial_witness, input2[0])?;
let input2_y = input_to_value(initial_witness, input2[1])?;
let input2_infinite = input_to_value(initial_witness, input2[2])?;
let input1_x = input_to_value(initial_witness, input1[0], false)?;
let input1_y = input_to_value(initial_witness, input1[1], false)?;
let input1_infinite = input_to_value(initial_witness, input1[2], false)?;
let input2_x = input_to_value(initial_witness, input2[0], false)?;
let input2_y = input_to_value(initial_witness, input2[1], false)?;
let input2_infinite = input_to_value(initial_witness, input2[2], false)?;
let (res_x, res_y, res_infinite) = backend.ec_add(
&input1_x,
&input1_y,
Expand Down
9 changes: 5 additions & 4 deletions acvm-repo/acvm/src/pwg/blackbox/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ fn get_hash_input<F: AcirField>(
for input in inputs.iter() {
let num_bits = input.num_bits() as usize;

let witness_assignment = input_to_value(initial_witness, *input)?;
let witness_assignment = input_to_value(initial_witness, *input, false)?;
let bytes = witness_assignment.fetch_nearest_bytes(num_bits);
message_input.extend(bytes);
}

// Truncate the message if there is a `message_size` parameter given
match message_size {
Some(input) => {
let num_bytes_to_take = input_to_value(initial_witness, *input)?.to_u128() as usize;
let num_bytes_to_take =
input_to_value(initial_witness, *input, false)?.to_u128() as usize;

// If the number of bytes to take is more than the amount of bytes available
// in the message, then we error.
Expand Down Expand Up @@ -78,7 +79,7 @@ fn to_u32_array<const N: usize, F: AcirField>(
) -> Result<[u32; N], OpcodeResolutionError<F>> {
let mut result = [0; N];
for (it, input) in result.iter_mut().zip(inputs) {
let witness_value = input_to_value(initial_witness, *input)?;
let witness_value = input_to_value(initial_witness, *input, false)?;
*it = witness_value.to_u128() as u32;
}
Ok(result)
Expand Down Expand Up @@ -133,7 +134,7 @@ pub(crate) fn solve_poseidon2_permutation_opcode<F: AcirField>(
// Read witness assignments
let mut state = Vec::new();
for input in inputs.iter() {
let witness_assignment = input_to_value(initial_witness, *input)?;
let witness_assignment = input_to_value(initial_witness, *input, false)?;
state.push(witness_assignment);
}

Expand Down
7 changes: 5 additions & 2 deletions acvm-repo/acvm/src/pwg/blackbox/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ fn solve_logic_opcode<F: AcirField>(
result: Witness,
logic_op: impl Fn(F, F) -> F,
) -> Result<(), OpcodeResolutionError<F>> {
let w_l_value = input_to_value(initial_witness, *a)?;
let w_r_value = input_to_value(initial_witness, *b)?;
// TODO(https://github.com/noir-lang/noir/issues/5985): re-enable these once we figure out how to combine these with existing
// noirc_frontend/noirc_evaluator overflow error messages
let skip_bitsize_checks = true;
let w_l_value = input_to_value(initial_witness, *a, skip_bitsize_checks)?;
let w_r_value = input_to_value(initial_witness, *b, skip_bitsize_checks)?;
let assignment = logic_op(w_l_value, w_r_value);

insert_value(&result, assignment, initial_witness)
Expand Down
8 changes: 4 additions & 4 deletions acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ fn first_missing_assignment<F>(
inputs: &[FunctionInput<F>],
) -> Option<Witness> {
inputs.iter().find_map(|input| {
if let ConstantOrWitnessEnum::Witness(witness) = input.input {
if witness_assignments.contains_key(&witness) {
if let ConstantOrWitnessEnum::Witness(ref witness) = input.input_ref() {
if witness_assignments.contains_key(witness) {
None
} else {
Some(witness)
Some(*witness)
}
} else {
None
Expand Down Expand Up @@ -108,7 +108,7 @@ pub(crate) fn solve<F: AcirField>(
for (it, input) in state.iter_mut().zip(inputs.as_ref()) {
let num_bits = input.num_bits() as usize;
assert_eq!(num_bits, 64);
let witness_assignment = input_to_value(initial_witness, *input)?;
let witness_assignment = input_to_value(initial_witness, *input, false)?;
let lane = witness_assignment.try_to_u64();
*it = lane.unwrap();
}
Expand Down
4 changes: 2 additions & 2 deletions acvm-repo/acvm/src/pwg/blackbox/pedersen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(super) fn pedersen<F: AcirField>(
outputs: (Witness, Witness),
) -> Result<(), OpcodeResolutionError<F>> {
let scalars: Result<Vec<_>, _> =
inputs.iter().map(|input| input_to_value(initial_witness, *input)).collect();
inputs.iter().map(|input| input_to_value(initial_witness, *input, false)).collect();
let scalars: Vec<_> = scalars?.into_iter().collect();

let (res_x, res_y) = backend.pedersen_commitment(&scalars, domain_separator)?;
Expand All @@ -36,7 +36,7 @@ pub(super) fn pedersen_hash<F: AcirField>(
output: Witness,
) -> Result<(), OpcodeResolutionError<F>> {
let scalars: Result<Vec<_>, _> =
inputs.iter().map(|input| input_to_value(initial_witness, *input)).collect();
inputs.iter().map(|input| input_to_value(initial_witness, *input, false)).collect();
let scalars: Vec<_> = scalars?.into_iter().collect();

let res = backend.pedersen_hash(&scalars, domain_separator)?;
Expand Down
5 changes: 4 additions & 1 deletion acvm-repo/acvm/src/pwg/blackbox/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ pub(crate) fn solve_range_opcode<F: AcirField>(
initial_witness: &WitnessMap<F>,
input: &FunctionInput<F>,
) -> Result<(), OpcodeResolutionError<F>> {
let w_value = input_to_value(initial_witness, *input)?;
// TODO(https://github.com/noir-lang/noir/issues/5985):
// re-enable bitsize checks
let skip_bitsize_checks = true;
let w_value = input_to_value(initial_witness, *input, skip_bitsize_checks)?;
if w_value.num_bits() > input.num_bits() {
return Err(OpcodeResolutionError::UnsatisfiedConstrain {
opcode_location: ErrorLocation::Unresolved,
Expand Down
4 changes: 2 additions & 2 deletions acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ pub(crate) fn schnorr_verify<F: AcirField>(
message: &[FunctionInput<F>],
output: Witness,
) -> Result<(), OpcodeResolutionError<F>> {
let public_key_x: &F = &input_to_value(initial_witness, public_key_x)?;
let public_key_y: &F = &input_to_value(initial_witness, public_key_y)?;
let public_key_x: &F = &input_to_value(initial_witness, public_key_x, false)?;
let public_key_y: &F = &input_to_value(initial_witness, public_key_y, false)?;

let signature = to_u8_array(initial_witness, signature)?;
let message = to_u8_vec(initial_witness, message)?;
Expand Down
4 changes: 2 additions & 2 deletions acvm-repo/acvm/src/pwg/blackbox/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub(crate) fn to_u8_array<const N: usize, F: AcirField>(
) -> Result<[u8; N], OpcodeResolutionError<F>> {
let mut result = [0; N];
for (it, input) in result.iter_mut().zip(inputs) {
let witness_value_bytes = input_to_value(initial_witness, *input)?.to_be_bytes();
let witness_value_bytes = input_to_value(initial_witness, *input, false)?.to_be_bytes();
let byte = witness_value_bytes
.last()
.expect("Field element must be represented by non-zero amount of bytes");
Expand All @@ -23,7 +23,7 @@ pub(crate) fn to_u8_vec<F: AcirField>(
) -> Result<Vec<u8>, OpcodeResolutionError<F>> {
let mut result = Vec::with_capacity(inputs.len());
for input in inputs {
let witness_value_bytes = input_to_value(initial_witness, *input)?.to_be_bytes();
let witness_value_bytes = input_to_value(initial_witness, *input, false)?.to_be_bytes();
let byte = witness_value_bytes
.last()
.expect("Field element must be represented by non-zero amount of bytes");
Expand Down
Loading

0 comments on commit 8712f4c

Please sign in to comment.