Skip to content

Commit

Permalink
fix(ssa refactor): function inlining orphans calls (#1747)
Browse files Browse the repository at this point in the history
* chore(ssa refactor): inling regression test

* chore(ssa refactor): rename for clarity

* chore(ssa refactor): skip seen blocks

* fix(ssa refactor): fix function inlining

* chore(ssa refactor): fix factorial inlining test

* chore(ssa refactor): typo

* chore(ssa refactor): clearer doc comments
  • Loading branch information
joss-aztec authored Jun 21, 2023
1 parent a5e631a commit f30a90f
Showing 1 changed file with 116 additions and 9 deletions.
125 changes: 116 additions & 9 deletions crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ struct PerFunctionContext<'function> {
/// argument values.
values: HashMap<ValueId, ValueId>,

/// Maps BasicBlockIds in the function being inlined to the new BasicBlockIds to use in the
/// function being inlined into.
/// Maps blocks in the source function to blocks in the function being inlined into, where
/// each mapping is from the start of a source block to an inlined block in which the
/// analogous program point occurs.
///
/// Note that the starts of multiple source blocks can map into a single inlined block.
/// Conversely the whole of a source block is not guaranteed to map into a single inlined
/// block.
blocks: HashMap<BasicBlockId, BasicBlockId>,

/// Maps InstructionIds from the function being inlined to the function being inlined into.
Expand Down Expand Up @@ -222,7 +227,9 @@ impl<'function> PerFunctionContext<'function> {
new_value
}

/// Translate a block id from the source function to one of the target function.
/// Translates the program point representing the start of the given `source_block` to the
/// inlined block in which the analogous program point occurs. (Once inlined, the source
/// block's analogous program region may span multiple inlined blocks.)
///
/// If the block isn't already known, this will insert a new block into the target function
/// with the same parameter types as the source block.
Expand Down Expand Up @@ -282,11 +289,14 @@ impl<'function> PerFunctionContext<'function> {
let mut function_returns = vec![];

while let Some(source_block_id) = block_queue.pop() {
if seen_blocks.contains(&source_block_id) {
continue;
}
let translated_block_id = self.translate_block(source_block_id, &mut block_queue);
self.context.builder.switch_to_block(translated_block_id);

seen_blocks.insert(source_block_id);
self.inline_block(ssa, source_block_id);
self.inline_block_instructions(ssa, source_block_id);

if let Some((block, values)) =
self.handle_terminator_instruction(source_block_id, &mut block_queue)
Expand Down Expand Up @@ -331,7 +341,7 @@ impl<'function> PerFunctionContext<'function> {

/// Inline each instruction in the given block into the function being inlined into.
/// This may recurse if it finds another function to inline if a call instruction is within this block.
fn inline_block(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
fn inline_block_instructions(&mut self, ssa: &Ssa, block_id: BasicBlockId) {
let block = &self.source_function.dfg[block_id];
for id in block.instructions() {
match &self.source_function.dfg[*id] {
Expand Down Expand Up @@ -448,7 +458,11 @@ impl<'function> PerFunctionContext<'function> {
if self.inlining_main {
self.context.builder.terminate_with_return(return_values.clone());
}
let block_id = self.translate_block(block_id, block_queue);
// Note that `translate_block` would take us back to the point at which the
// inlining of this source block began. Since additional blocks may have been
// inlined since, we are interested in the block representing the current program
// point, obtained via `current_block`.
let block_id = self.context.builder.current_block();
Some((block_id, return_values))
}
}
Expand All @@ -457,10 +471,13 @@ impl<'function> PerFunctionContext<'function> {

#[cfg(test)]
mod test {
use acvm::FieldElement;

use crate::ssa_refactor::{
ir::{
basic_block::BasicBlockId,
function::RuntimeType,
instruction::{BinaryOp, TerminatorInstruction},
instruction::{BinaryOp, Intrinsic, TerminatorInstruction},
map::Id,
types::Type,
},
Expand Down Expand Up @@ -622,15 +639,26 @@ mod test {
// b0():
// jmp b1()
// b1():
// jmp b2()
// b2():
// jmp b3()
// b3():
// jmp b4()
// b4():
// jmp b5()
// b5():
// jmp b6()
// b6():
// return Field 120
// }
let inlined = ssa.inline_functions();
assert_eq!(inlined.functions.len(), 1);

let main = inlined.main();
let b1 = &main.dfg[b1];
let b6_id: BasicBlockId = Id::test_new(6);
let b6 = &main.dfg[b6_id];

match b1.terminator() {
match b6.terminator() {
Some(TerminatorInstruction::Return { return_values }) => {
assert_eq!(return_values.len(), 1);
let value = main
Expand All @@ -643,4 +671,83 @@ mod test {
other => unreachable!("Unexpected terminator {other:?}"),
}
}

#[test]
fn displaced_return_mapping() {
// This test is designed specifically to catch a regression in which the ids of blocks
// terminated by returns are badly tracked. As a result, the continuation of a source
// block after a call instruction could but inlined into a block that's already been
// terminated, producing an incorrect order and orphaning successors.

// fn main f0 {
// b0(v0: u1):
// v2 = call f1(v0)
// call println(v2)
// return
// }
// fn inner1 f1 {
// b0(v0: u1):
// v2 = call f2(v0)
// return v2
// }
// fn inner2 f2 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// return v3
// b2():
// jmp b3(Field 2)
// }
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let main_cond = builder.add_parameter(Type::bool());
let inner1_id = Id::test_new(1);
let inner1 = builder.import_function(inner1_id);
let main_v2 = builder.insert_call(inner1, vec![main_cond], vec![Type::field()])[0];
let println = builder.import_intrinsic_id(Intrinsic::Println);
builder.insert_call(println, vec![main_v2], vec![]);
builder.terminate_with_return(vec![]);

builder.new_function("inner1".into(), inner1_id);
let inner1_cond = builder.add_parameter(Type::bool());
let inner2_id = Id::test_new(2);
let inner2 = builder.import_function(inner2_id);
let inner1_v2 = builder.insert_call(inner2, vec![inner1_cond], vec![Type::field()])[0];
builder.terminate_with_return(vec![inner1_v2]);

builder.new_function("inner2".into(), inner2_id);
let inner2_cond = builder.add_parameter(Type::bool());
let then_block = builder.insert_block();
let else_block = builder.insert_block();
let join_block = builder.insert_block();
builder.terminate_with_jmpif(inner2_cond, then_block, else_block);
builder.switch_to_block(then_block);
let one = builder.numeric_constant(FieldElement::one(), Type::field());
builder.terminate_with_jmp(join_block, vec![one]);
builder.switch_to_block(else_block);
let two = builder.numeric_constant(FieldElement::from(2_u128), Type::field());
builder.terminate_with_jmp(join_block, vec![two]);
let join_param = builder.add_block_parameter(join_block, Type::field());
builder.switch_to_block(join_block);
builder.terminate_with_return(vec![join_param]);

let ssa = builder.finish().inline_functions();
// Expected result:
// fn main f3 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// jmp b3(Field 1)
// b3(v3: Field):
// call println(v3)
// return
// b2():
// jmp b3(Field 2)
// }
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 4);
}
}

0 comments on commit f30a90f

Please sign in to comment.