Skip to content

Commit

Permalink
feat(perf): Track last loads per block in mem2reg and remove them if …
Browse files Browse the repository at this point in the history
…possible (#6088)

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
Co-authored-by: jfecher <jake@aztecprotocol.com>
  • Loading branch information
3 people authored Nov 26, 2024
1 parent 6491175 commit 624ae6c
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 3 deletions.
223 changes: 221 additions & 2 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! - A reference with 0 aliases means we were unable to find which reference this reference
//! refers to. If such a reference is stored to, we must conservatively invalidate every
//! reference in the current block.
//! - We also track the last load instruction to each address per block.
//!
//! From there, to figure out the value of each reference at the end of block, iterate each instruction:
//! - On `Instruction::Allocate`:
Expand All @@ -28,6 +29,13 @@
//! - Furthermore, if the result of the load is a reference, mark the result as an alias
//! of the reference it dereferences to (if known).
//! - If which reference it dereferences to is not known, this load result has no aliases.
//! - We also track the last instance of a load instruction to each address in a block.
//! If we see that the last load instruction was from the same address as the current load instruction,
//! we move to replace the result of the current load with the result of the previous load.
//! This removal requires a couple of conditions:
//! - No store occurs to that address before the next load.
//! - The address is not used as an argument to a call between the loads.
//! This optimization helps us remove repeated loads for which there are no known values.
//! - On `Instruction::Store { address, value }`:
//! - If the address of the store is known:
//! - If the address has exactly 1 alias:
Expand All @@ -40,11 +48,13 @@
//! - Conservatively mark every alias in the block to `Unknown`.
//! - Additionally, if there were no Loads to any alias of the address between this Store and
//! the previous Store to the same address, the previous store can be removed.
//! - Remove the instance of the last load instruction to the address and its aliases
//! - On `Instruction::Call { arguments }`:
//! - If any argument of the call is a reference, set the value of each alias of that
//! reference to `Unknown`
//! - Any builtin functions that may return aliases if their input also contains a
//! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc.
//! - Remove the instance of the last load instruction for any reference arguments and their aliases
//!
//! On a terminator instruction:
//! - If the terminator is a `Jmp`:
Expand Down Expand Up @@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> {
if let Some(first_predecessor) = predecessors.next() {
let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default();
first.last_stores.clear();
// Last loads are tracked per block. During unification we are creating a new block from the current one,
// so we must clear the last loads of the current block before we return the new block.
first.last_loads.clear();

// Note that we have to start folding with the first block as the accumulator.
// If we started with an empty block, an empty block union'd with any other block
Expand Down Expand Up @@ -410,6 +423,28 @@ impl<'f> PerFunctionContext<'f> {

self.last_loads.insert(address, (instruction, block_id));
}

// Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads).
// If we do have a repeat load, we can remove the current load and map its result to the previous load's result.
if let Some(last_load) = references.last_loads.get(&address) {
let Instruction::Load { address: previous_address } =
&self.inserter.function.dfg[*last_load]
else {
panic!("Expected a Load instruction here");
};
let result = self.inserter.function.dfg.instruction_results(instruction)[0];
let previous_result =
self.inserter.function.dfg.instruction_results(*last_load)[0];
if *previous_address == address {
self.inserter.map_value(result, previous_result);
self.instructions_to_remove.insert(instruction);
}
}
// We want to set the load for every load even if the address has a known value
// and the previous load instruction was removed.
// We are safe to still remove a repeat load in this case as we are mapping from the current load's
// result to the previous load, which if it was removed should already have a mapping to the known value.
references.set_last_load(address, instruction);
}
Instruction::Store { address, value } => {
let address = self.inserter.function.dfg.resolve(*address);
Expand All @@ -435,6 +470,8 @@ impl<'f> PerFunctionContext<'f> {
}

references.set_known_value(address, value);
// If we see a store to an address, the last load to that address needs to remain.
references.keep_last_load_for(address, self.inserter.function);
references.last_stores.insert(address, instruction);
}
Instruction::Allocate => {
Expand Down Expand Up @@ -542,6 +579,9 @@ impl<'f> PerFunctionContext<'f> {
let value = self.inserter.function.dfg.resolve(*value);
references.set_unknown(value);
references.mark_value_used(value, self.inserter.function);

// If a reference is an argument to a call, the last load to that address and its aliases needs to remain.
references.keep_last_load_for(value, self.inserter.function);
}
}
}
Expand Down Expand Up @@ -572,6 +612,12 @@ impl<'f> PerFunctionContext<'f> {
let destination_parameters = self.inserter.function.dfg[*destination].parameters();
assert_eq!(destination_parameters.len(), arguments.len());

// If we have multiple parameters that alias that same argument value,
// then those parameters also alias each other.
// We save parameters with repeat arguments to later mark those
// parameters as aliasing one another.
let mut arg_set: HashMap<ValueId, BTreeSet<ValueId>> = HashMap::default();

// Add an alias for each reference parameter
for (parameter, argument) in destination_parameters.iter().zip(arguments) {
if self.inserter.function.dfg.value_is_reference(*parameter) {
Expand All @@ -581,10 +627,27 @@ impl<'f> PerFunctionContext<'f> {
if let Some(aliases) = references.aliases.get_mut(expression) {
// The argument reference is possibly aliased by this block parameter
aliases.insert(*parameter);

// Check if we have seen the same argument
let seen_parameters = arg_set.entry(argument).or_default();
// Add the current parameter to the parameters we have seen for this argument.
// The previous parameters and the current one alias one another.
seen_parameters.insert(*parameter);
}
}
}
}

// Set the aliases of the parameters
for (_, aliased_params) in arg_set {
for param in aliased_params.iter() {
self.set_aliases(
references,
*param,
AliasSet::known_multiple(aliased_params.clone()),
);
}
}
}
TerminatorInstruction::Return { return_values, .. } => {
// Removing all `last_stores` for each returned reference is more important here
Expand Down Expand Up @@ -900,7 +963,7 @@ mod tests {
// v10 = eq v9, Field 2
// constrain v9 == Field 2
// v11 = load v2
// v12 = load v10
// v12 = load v11
// v13 = eq v12, Field 2
// constrain v11 == Field 2
// return
Expand Down Expand Up @@ -959,7 +1022,7 @@ mod tests {
let main = ssa.main();
assert_eq!(main.reachable_blocks().len(), 4);

// The store from the original SSA should remain
// The stores from the original SSA should remain
assert_eq!(count_stores(main.entry_block(), &main.dfg), 2);
assert_eq!(count_stores(b2, &main.dfg), 1);

Expand Down Expand Up @@ -1006,4 +1069,160 @@ mod tests {
let main = ssa.main();
assert_eq!(count_loads(main.entry_block(), &main.dfg), 1);
}

#[test]
fn remove_repeat_loads() {
// This test starts with two consecutive loads from the same address whose value is unknown.
// Specifically, look for the two `load v2` instructions in `b3`.
// We should be able to remove the second, repeated load.
let src = "
acir(inline) fn main f0 {
b0():
v0 = allocate -> &mut Field
store Field 0 at v0
v2 = allocate -> &mut &mut Field
store v0 at v2
jmp b1(Field 0)
b1(v3: Field):
v4 = eq v3, Field 0
jmpif v4 then: b2, else: b3
b2():
v5 = load v2 -> &mut Field
store Field 2 at v5
v8 = add v3, Field 1
jmp b1(v8)
b3():
v9 = load v0 -> Field
v10 = eq v9, Field 2
constrain v9 == Field 2
v11 = load v2 -> &mut Field
v12 = load v2 -> &mut Field
v13 = load v12 -> Field
v14 = eq v13, Field 2
constrain v13 == Field 2
return
}
";

let ssa = Ssa::from_str(src).unwrap();

// The repeated load from `v2` (printed as `v3` in the normalized output below) should be removed.
// `b3` (printed as `b2` below) should only have three loads now rather than four previously.
//
// All stores are expected to remain.
let expected = "
acir(inline) fn main f0 {
b0():
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = allocate -> &mut &mut Field
store v1 at v3
jmp b1(Field 0)
b1(v0: Field):
v4 = eq v0, Field 0
jmpif v4 then: b3, else: b2
b3():
v11 = load v3 -> &mut Field
store Field 2 at v11
v13 = add v0, Field 1
jmp b1(v13)
b2():
v5 = load v1 -> Field
v7 = eq v5, Field 2
constrain v5 == Field 2
v8 = load v3 -> &mut Field
v9 = load v8 -> Field
v10 = eq v9, Field 2
constrain v9 == Field 2
return
}
";

let ssa = ssa.mem2reg();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn keep_repeat_loads_passed_to_a_call() {
// This test mirrors `remove_repeat_loads` above, except that `call f1(v3)` sits
// between the two repeated `load v3` instructions in `b2`.
// Since the address is passed to a call between the loads, both loads must be kept.
// NOTE(review): `f1` declares its parameter as `&mut Field` while `v3` is `&mut &mut Field`;
// confirm that the SSA parser is not expected to type-check call arguments here.
let src = "
acir(inline) fn main f0 {
b0():
v1 = allocate -> &mut Field
store Field 0 at v1
v3 = allocate -> &mut &mut Field
store v1 at v3
jmp b1(Field 0)
b1(v0: Field):
v4 = eq v0, Field 0
jmpif v4 then: b3, else: b2
b3():
v13 = load v3 -> &mut Field
store Field 2 at v13
v15 = add v0, Field 1
jmp b1(v15)
b2():
v5 = load v1 -> Field
v7 = eq v5, Field 2
constrain v5 == Field 2
v8 = load v3 -> &mut Field
call f1(v3)
v10 = load v3 -> &mut Field
v11 = load v10 -> Field
v12 = eq v11, Field 2
constrain v11 == Field 2
return
}
acir(inline) fn foo f1 {
b0(v0: &mut Field):
return
}
";

let ssa = Ssa::from_str(src).unwrap();

let ssa = ssa.mem2reg();
// We expect the program to be unchanged: the same `src` text is used as the expected output.
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn keep_repeat_loads_with_alias_store() {
// The block parameters v1, v2, and v3 of b3 alias one another: every jump to b3 passes the
// same reference for all three. We want to make sure that repeat loads from v1 are not removed
// when a store to one of its aliases (v2 or v3) occurs in between those loads.
let src = "
acir(inline) fn main f0 {
b0(v0: u1):
jmpif v0 then: b2, else: b1
b2():
v6 = allocate -> &mut Field
store Field 0 at v6
jmp b3(v6, v6, v6)
b3(v1: &mut Field, v2: &mut Field, v3: &mut Field):
v8 = load v1 -> Field
store Field 2 at v2
v10 = load v1 -> Field
store Field 1 at v3
v11 = load v1 -> Field
store Field 3 at v3
v13 = load v1 -> Field
constrain v8 == Field 0
constrain v10 == Field 2
constrain v11 == Field 1
constrain v13 == Field 3
return
b1():
v4 = allocate -> &mut Field
store Field 1 at v4
jmp b3(v4, v4, v4)
}
";

let ssa = Ssa::from_str(src).unwrap();

let ssa = ssa.mem2reg();
// We expect the program to be unchanged: the same `src` text is used as the expected output.
assert_normalized_ssa_equals(ssa, src);
}
}
4 changes: 4 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ impl AliasSet {
Self { aliases: Some(aliases) }
}

/// Builds an `AliasSet` whose known aliases are exactly the given set of `values`.
pub(super) fn known_multiple(values: BTreeSet<ValueId>) -> AliasSet {
    let aliases = Some(values);
    Self { aliases }
}

/// In rare cases, such as when creating an empty array of references, the set of aliases for a
/// particular value will be known to be zero, which is distinct from being unknown and
/// possibly referring to any alias.
Expand Down
13 changes: 13 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ pub(super) struct Block {

/// The last instance of a `Store` instruction to each address in this block
pub(super) last_stores: im::OrdMap<ValueId, InstructionId>,

/// The last instance of a `Load` instruction to each address in this block
pub(super) last_loads: im::OrdMap<ValueId, InstructionId>,
}

/// An `Expression` here is used to represent a canonical key
Expand Down Expand Up @@ -237,4 +240,14 @@ impl Block {

Cow::Owned(AliasSet::unknown())
}

/// Records `instruction` as the last load from `address` in this block.
///
/// Any load previously recorded for the same address is overwritten.
pub(super) fn set_last_load(&mut self, address: ValueId, instruction: InstructionId) {
    // The previously recorded load (if any) is intentionally discarded.
    let _previous = self.last_loads.insert(address, instruction);
}

/// Stops tracking the last load of `address` and of each of its aliases.
///
/// The address is resolved through the function's data-flow graph first. Once the
/// entries are dropped from `last_loads`, a later lookup for these addresses finds
/// no previous load recorded.
pub(super) fn keep_last_load_for(&mut self, address: ValueId, function: &Function) {
    let resolved = function.dfg.resolve(address);

    // Forget the load recorded for the address itself...
    self.last_loads.remove(&resolved);

    // ...and for every alias that `for_each_alias_of` visits.
    self.for_each_alias_of(resolved, |this, aliased| this.last_loads.remove(&aliased));
}
}
2 changes: 1 addition & 1 deletion tooling/debugger/tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod tests {
let nargo_bin =
cargo_bin("nargo").into_os_string().into_string().expect("Cannot parse nargo path");

let timeout_seconds = 25;
let timeout_seconds = 30;
let mut dbg_session =
spawn_bash(Some(timeout_seconds * 1000)).expect("Could not start bash session");

Expand Down

0 comments on commit 624ae6c

Please sign in to comment.