Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Use helper functions for getting values of AcirVars #2194

Merged
merged 3 commits into from
Aug 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 37 additions & 70 deletions crates/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,29 @@ impl AcirContext {
self.add_data(var_data)
}

pub(crate) fn get_location(&mut self) -> Option<Location> {
pub(crate) fn get_location(&self) -> Option<Location> {
self.acir_ir.current_location
}

pub(crate) fn set_location(&mut self, location: Option<Location>) {
self.acir_ir.current_location = location;
}

/// Converts an [`AcirVar`] to a [`Witness`]
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let expression = self.var_to_expression(var)?;
Ok(self.acir_ir.get_or_create_witness(&expression))
}

/// Converts an [`AcirVar`] to an [`Expression`]
fn var_to_expression(&self, var: AcirVar) -> Result<Expression, InternalError> {
let var_data = match self.vars.get(&var) {
Some(var_data) => var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
Ok(var_data.to_expression().into_owned())
}

/// True if the given AcirVar refers to a constant one value
pub(crate) fn is_constant_one(&self, var: &AcirVar) -> bool {
match self.vars[var] {
Expand Down Expand Up @@ -246,11 +261,8 @@ impl AcirContext {
/// Returns an `AcirVar` that is `1` if `lhs` equals `rhs` and
/// 0 otherwise.
pub(crate) fn eq_var(&mut self, lhs: AcirVar, rhs: AcirVar) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;

let is_equal_witness = self.acir_ir.is_equal(&lhs_expr, &rhs_expr);
let result_var = self.add_data(AcirVarData::Witness(is_equal_witness));
Expand Down Expand Up @@ -479,13 +491,9 @@ impl AcirContext {
bit_size: u32,
predicate: AcirVar,
) -> Result<(AcirVar, AcirVar), RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];
let predicate_data = &self.vars[&predicate];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let predicate_expr = predicate_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;
let predicate_expr = self.var_to_expression(predicate)?;

let (quotient, remainder) =
self.acir_ir.euclidean_division(&lhs_expr, &rhs_expr, bit_size, &predicate_expr)?;
Expand All @@ -500,24 +508,15 @@ impl AcirContext {
/// and |remainder| < |rhs|
/// and remainder has the same sign than lhs
/// Note that this is not the euclidian division, where we have instead remainder < |rhs|
///
///
///
///

fn signed_division_var(
&mut self,
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
) -> Result<(AcirVar, AcirVar), RuntimeError> {
let lhs_data = &self.vars[&lhs].clone();
let rhs_data = &self.vars[&rhs].clone();
let l_witness = self.var_to_witness(lhs)?;
let r_witness = self.var_to_witness(rhs)?;

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();
let l_witness = self.acir_ir.get_or_create_witness(&lhs_expr);
let r_witness = self.acir_ir.get_or_create_witness(&rhs_expr);
assert_ne!(bit_size, 0, "signed integer should have at least one bit");
let (q, r) =
self.acir_ir.signed_division(&l_witness.into(), &r_witness.into(), bit_size)?;
Expand Down Expand Up @@ -571,18 +570,7 @@ impl AcirContext {
/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
/// `GeneratedAcir`'s return witnesses.
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
let acir_var_data = match self.vars.get(&acir_var) {
Some(acir_var_data) => acir_var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
// TODO: Add caching to prevent expressions from being needlessly duplicated
let witness = match acir_var_data {
AcirVarData::Const(constant) => {
self.acir_ir.get_or_create_witness(&Expression::from(*constant))
}
AcirVarData::Expr(expr) => self.acir_ir.get_or_create_witness(expr),
AcirVarData::Witness(witness) => *witness,
};
let witness = self.var_to_witness(acir_var)?;
self.acir_ir.push_return_witness(witness);
Ok(())
}
Expand All @@ -593,11 +581,9 @@ impl AcirContext {
variable: AcirVar,
numeric_type: &NumericType,
) -> Result<AcirVar, RuntimeError> {
let data = &self.vars[&variable];
match numeric_type {
NumericType::Signed { bit_size } | NumericType::Unsigned { bit_size } => {
let data_expr = data.to_expression();
let witness = self.acir_ir.get_or_create_witness(&data_expr);
let witness = self.var_to_witness(variable)?;
self.acir_ir.range_constraint(witness, *bit_size)?;
}
NumericType::NativeField => {
Expand All @@ -616,8 +602,7 @@ impl AcirContext {
rhs: u32,
max_bit_size: u32,
) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let lhs_expr = lhs_data.to_expression();
let lhs_expr = self.var_to_expression(lhs)?;

// 2^{rhs}
let divisor = FieldElement::from(2_i128).pow(&FieldElement::from(rhs as i128));
Expand All @@ -641,17 +626,12 @@ impl AcirContext {
bit_size: u32,
predicate: AcirVar,
) -> Result<AcirVar, RuntimeError> {
let lhs_data = &self.vars[&lhs];
let rhs_data = &self.vars[&rhs];

let lhs_expr = lhs_data.to_expression();
let rhs_expr = rhs_data.to_expression();

let predicate_data = &self.vars[&predicate];
let predicate = predicate_data.to_expression().into_owned();
let lhs_expr = self.var_to_expression(lhs)?;
let rhs_expr = self.var_to_expression(rhs)?;
let predicate_expr = self.var_to_expression(predicate)?;

let is_greater_than_eq =
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate)?;
self.acir_ir.more_than_eq_comparison(&lhs_expr, &rhs_expr, bit_size, predicate_expr)?;

Ok(self.add_data(AcirVarData::Witness(is_greater_than_eq)))
}
Expand Down Expand Up @@ -736,13 +716,10 @@ impl AcirContext {
for input in inputs {
let mut single_val_witnesses = Vec::new();
for (input, typ) in input.flatten() {
let var_data = &self.vars[&input];

// Intrinsics only accept Witnesses. This is not a limitation of the
// intrinsics, its just how we have defined things. Ideally, we allow
// constants too.
let expr = var_data.to_expression();
let witness = self.acir_ir.get_or_create_witness(&expr);
let witness = self.var_to_witness(input)?;
let num_bits = typ.bit_size();
single_val_witnesses.push(FunctionInput { witness, num_bits });
}
Expand Down Expand Up @@ -785,10 +762,10 @@ impl AcirContext {
}
};

let input_expr = &self.vars[&input_var].to_expression();
let input_expr = self.var_to_expression(input_var)?;

let bit_size = u32::BITS - (radix - 1).leading_zeros();
let limbs = self.acir_ir.radix_le_decompose(input_expr, radix, limb_count, bit_size)?;
let limbs = self.acir_ir.radix_le_decompose(&input_expr, radix, limb_count, bit_size)?;

let mut limb_vars = vecmap(limbs, |witness| {
let witness = self.add_data(AcirVarData::Witness(witness));
Expand Down Expand Up @@ -873,9 +850,7 @@ impl AcirContext {
outputs: Vec<AcirType>,
) -> Result<Vec<AcirValue>, InternalError> {
let b_inputs = try_vecmap(inputs, |i| match i {
AcirValue::Var(var, _) => {
Ok(BrilligInputs::Single(self.vars[&var].to_expression().into_owned()))
}
AcirValue::Var(var, _) => Ok(BrilligInputs::Single(self.var_to_expression(var)?)),
AcirValue::Array(vars) => {
let mut var_expressions: Vec<Expression> = Vec::new();
for var in vars {
Expand Down Expand Up @@ -904,7 +879,7 @@ impl AcirContext {
acir_value
}
});
let predicate = self.vars[&predicate].to_expression().into_owned();
let predicate = self.var_to_expression(predicate)?;
self.acir_ir.brillig(Some(predicate), code, b_inputs, b_outputs);

Ok(outputs_var)
Expand All @@ -917,7 +892,7 @@ impl AcirContext {
) -> Result<(), InternalError> {
match input {
AcirValue::Var(var, _) => {
var_expressions.push(self.vars[&var].to_expression().into_owned());
var_expressions.push(self.var_to_expression(var)?);
}
AcirValue::Array(vars) => {
for var in vars {
Expand Down Expand Up @@ -988,7 +963,7 @@ impl AcirContext {
) -> Result<Vec<AcirVar>, RuntimeError> {
let len = inputs.len();
// Convert the inputs into expressions
let inputs_expr = vecmap(inputs, |input| self.vars[&input].to_expression().into_owned());
let inputs_expr = try_vecmap(inputs, |input| self.var_to_expression(input))?;
// Generate output witnesses
let outputs_witness = vecmap(0..len, |_| self.acir_ir.next_witness_index());
let output_expr =
Expand All @@ -1007,14 +982,6 @@ impl AcirContext {

Ok(outputs_var)
}
/// Converts an AcirVar to a Witness
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
let var_data = match self.vars.get(&var) {
Some(var_data) => var_data,
None => return Err(InternalError::UndeclaredAcirVar { location: self.get_location() }),
};
Ok(self.acir_ir.get_or_create_witness(&var_data.to_expression()))
}

/// Constrain lhs to be less than rhs
fn less_than_constrain(
Expand Down