Skip to content

Commit

Permalink
fix(ssa refactor): Reset condition value during flattening pass (#1811)
Browse files Browse the repository at this point in the history
* Fix flattening forgetting to reset condition value

* Simplify test a little

* Add other test for 1792

* Fix test

* Grammar
  • Loading branch information
jfecher authored Jun 23, 2023
1 parent ca84c8d commit 2e330e0
Showing 1 changed file with 153 additions and 16 deletions.
169 changes: 153 additions & 16 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,11 @@ impl<'f> Context<'f> {
let else_branch =
self.inline_branch(block, else_block, old_condition, else_condition, zero);

// We must remember to reset whether side effects are enabled when both branches
// end, in addition to resetting the value of old_condition since it is set to
// known to be true/false within the then/else branch respectively.
self.insert_current_side_effects_enabled();
self.inserter.map_value(old_condition, old_condition);

// While there is a condition on the stack we don't compile outside the condition
// until it is popped. This ensures we inline the full then and else branches
Expand Down Expand Up @@ -494,10 +498,16 @@ impl<'f> Context<'f> {
let old_stores = std::mem::take(&mut self.store_values);
let old_allocations = std::mem::take(&mut self.local_allocations);

// Remember the old condition value is now known to be true/false within this branch
let known_value =
self.inserter.function.dfg.make_constant(condition_value, Type::bool());
self.inserter.map_value(old_condition, known_value);
// Optimization: within the then branch we know the condition to be true, so replace
// any references of it within this branch with true. Likewise, do the same with false
// with the else branch. We must be careful not to replace the condition if it is a
// known constant, otherwise we can end up setting 1 = 0 or vice-versa.
if self.inserter.function.dfg.get_numeric_constant(old_condition).is_none() {
let known_value =
self.inserter.function.dfg.make_constant(condition_value, Type::bool());

self.inserter.map_value(old_condition, known_value);
}

let final_block = self.inline_block(destination, &[]);

Expand Down Expand Up @@ -670,11 +680,12 @@ impl<'f> Context<'f> {

#[cfg(test)]
mod test {
use std::rc::Rc;

use crate::ssa_refactor::{
ir::{
dfg::DataFlowGraph,
function::RuntimeType,
function::{Function, RuntimeType},
instruction::{BinaryOp, Instruction, Intrinsic, TerminatorInstruction},
map::Id,
types::Type,
Expand Down Expand Up @@ -837,12 +848,7 @@ mod test {
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = main.dfg[main.entry_block()]
.instructions()
.iter()
.filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. }))
.count();

let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 2);
}

Expand Down Expand Up @@ -921,13 +927,16 @@ mod test {
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 1);

let store_count = main.dfg[main.entry_block()]
let store_count = count_instruction(main, |ins| matches!(ins, Instruction::Store { .. }));
assert_eq!(store_count, 4);
}

fn count_instruction(function: &Function, f: impl Fn(&Instruction) -> bool) -> usize {
function.dfg[function.entry_block()]
.instructions()
.iter()
.filter(|id| matches!(&main.dfg[**id], Instruction::Store { .. }))
.count();

assert_eq!(store_count, 4);
.filter(|id| f(&function.dfg[**id]))
.count()
}

#[test]
Expand Down Expand Up @@ -1196,4 +1205,132 @@ mod test {
_ => Vec::new(),
}
}

#[test]
fn should_not_merge_away_constraints() {
// Very simplified derived regression test for #1792
// Tests that it does not simplify to a true constraint an always-false constraint
// The original function is replaced by the following:
// fn main f1 {
// b0():
// jmpif u1 0 then: b1, else: b2
// b1():
// jmp b2()
// b2():
// constrain u1 0 // was incorrectly removed
// return
// }
let main_id = Id::test_new(1);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

builder.insert_block(); // entry

let b1 = builder.insert_block();
let b2 = builder.insert_block();
let v_false = builder.numeric_constant(0_u128, Type::bool());
builder.terminate_with_jmpif(v_false, b1, b2);

builder.switch_to_block(b1);
builder.terminate_with_jmp(b2, vec![]);

builder.switch_to_block(b2);
builder.insert_constrain(v_false); // should not be removed
builder.terminate_with_return(vec![]);

let ssa = builder.finish().flatten_cfg();
let main = ssa.main();

// Assert we have not incorrectly removed a constraint:
use Instruction::Constrain;
let constrain_count = count_instruction(main, |ins| matches!(ins, Constrain(_)));
assert_eq!(constrain_count, 1);
}

#[test]
fn should_not_merge_incorrectly_to_false() {
// Regression test for #1792
// Tests that it does not simplify a true constraint an always-false constraint
// fn main f1 {
// b0():
// v4 = call pedersen([Field 0], u32 0)
// v5 = array_get v4, index Field 0
// v6 = cast v5 as u32
// v8 = mod v6, u32 2
// v9 = cast v8 as u1
// v10 = allocate
// store Field 0 at v10
// jmpif v9 then: b1, else: b2
// b1():
// v14 = add v5, Field 1
// store v14 at v10
// jmp b3()
// b3():
// v12 = eq v9, u1 1
// constrain v12
// return
// b2():
// store Field 0 at v10
// jmp b3()
// }
let main_id = Id::test_new(1);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

builder.insert_block(); // b0
let b1 = builder.insert_block();
let b2 = builder.insert_block();
let b3 = builder.insert_block();

let element_type = Rc::new(vec![Type::field()]);
let zero = builder.field_constant(0_u128);
let zero_array = builder.array_constant(im::Vector::unit(zero), element_type.clone());
let i_zero = builder.numeric_constant(0_u128, Type::unsigned(32));
let pedersen =
builder.import_intrinsic_id(Intrinsic::BlackBox(acvm::acir::BlackBoxFunc::Pedersen));
let v4 = builder.insert_call(
pedersen,
vec![zero_array, i_zero],
vec![Type::Array(element_type, 2)],
)[0];
let v5 = builder.insert_array_get(v4, zero, Type::field());
let v6 = builder.insert_cast(v5, Type::unsigned(32));
let i_two = builder.numeric_constant(2_u128, Type::unsigned(32));
let v8 = builder.insert_binary(v6, BinaryOp::Mod, i_two);
let v9 = builder.insert_cast(v8, Type::bool());

let v10 = builder.insert_allocate();
builder.insert_store(v10, zero);

builder.terminate_with_jmpif(v9, b1, b2);

builder.switch_to_block(b1);
let one = builder.field_constant(1_u128);
let v14 = builder.insert_binary(v5, BinaryOp::Add, one);
builder.insert_store(v10, v14);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b2);
builder.insert_store(v10, zero);
builder.terminate_with_jmp(b3, vec![]);

builder.switch_to_block(b3);
let b_true = builder.numeric_constant(1_u128, Type::unsigned(1));
let v12 = builder.insert_binary(v9, BinaryOp::Eq, b_true);
builder.insert_constrain(v12);
builder.terminate_with_return(vec![]);

let ssa = builder.finish().flatten_cfg();
let main = ssa.main();

// Now assert that there is not an always-false constraint after flattening:
let mut constrain_count = 0;
for instruction in main.dfg[main.entry_block()].instructions() {
if let Instruction::Constrain(value) = main.dfg[*instruction] {
if let Some(constant) = main.dfg.get_numeric_constant(value) {
assert!(constant.is_one());
}
constrain_count += 1;
}
}
assert_eq!(constrain_count, 1);
}
}

0 comments on commit 2e330e0

Please sign in to comment.