diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs index 93c3567fa9a..4763ffffbd1 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs @@ -44,6 +44,9 @@ pub(crate) struct DominatorTree { /// After dominator tree computation has complete, this will contain a node for every /// reachable block, and no nodes for unreachable blocks. nodes: HashMap, + + /// Subsequent calls to `dominates` are cached to speed up access + cache: HashMap<(BasicBlockId, BasicBlockId), bool>, } /// Methods for querying the dominator tree. @@ -83,7 +86,21 @@ impl DominatorTree { /// This function panics if either of the blocks are unreachable. /// /// An instruction is considered to dominate itself. - pub(crate) fn dominates(&self, block_a_id: BasicBlockId, mut block_b_id: BasicBlockId) -> bool { + pub(crate) fn dominates(&mut self, block_a_id: BasicBlockId, block_b_id: BasicBlockId) -> bool { + if let Some(res) = self.cache.get(&(block_a_id, block_b_id)) { + return *res; + } + + let result = self.dominates_helper(block_a_id, block_b_id); + self.cache.insert((block_a_id, block_b_id), result); + result + } + + pub(crate) fn dominates_helper( + &self, + block_a_id: BasicBlockId, + mut block_b_id: BasicBlockId, + ) -> bool { // Walk up the dominator tree from "b" until we encounter or pass "a". Doing the // comparison on the reverse post-order may allows to test whether we have passed "a" // without waiting until we reach the root of the tree. @@ -104,7 +121,7 @@ impl DominatorTree { /// Allocate and compute a dominator tree from a pre-computed control flow graph and /// post-order counterpart. 
pub(crate) fn with_cfg_and_post_order(cfg: &ControlFlowGraph, post_order: &PostOrder) -> Self { - let mut dom_tree = DominatorTree { nodes: HashMap::new() }; + let mut dom_tree = DominatorTree { nodes: HashMap::new(), cache: HashMap::new() }; dom_tree.compute_dominator_tree(cfg, post_order); dom_tree } @@ -249,7 +266,7 @@ mod tests { block0_id, TerminatorInstruction::Return { return_values: vec![] }, ); - let dom_tree = DominatorTree::with_function(&func); + let mut dom_tree = DominatorTree::with_function(&func); assert!(dom_tree.dominates(block0_id, block0_id)); } @@ -308,7 +325,7 @@ mod tests { // unreachable, performing this query indicates an internal compiler error. #[test] fn unreachable_node_asserts() { - let (dt, b0, _b1, b2, b3) = unreachable_node_setup(); + let (mut dt, b0, _b1, b2, b3) = unreachable_node_setup(); assert!(dt.dominates(b0, b0)); assert!(dt.dominates(b0, b2)); @@ -326,42 +343,42 @@ mod tests { #[test] #[should_panic] fn unreachable_node_panic_b0_b1() { - let (dt, b0, b1, _b2, _b3) = unreachable_node_setup(); + let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup(); dt.dominates(b0, b1); } #[test] #[should_panic] fn unreachable_node_panic_b1_b0() { - let (dt, b0, b1, _b2, _b3) = unreachable_node_setup(); + let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup(); dt.dominates(b1, b0); } #[test] #[should_panic] fn unreachable_node_panic_b1_b1() { - let (dt, _b0, b1, _b2, _b3) = unreachable_node_setup(); + let (mut dt, _b0, b1, _b2, _b3) = unreachable_node_setup(); dt.dominates(b1, b1); } #[test] #[should_panic] fn unreachable_node_panic_b1_b2() { - let (dt, _b0, b1, b2, _b3) = unreachable_node_setup(); + let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup(); dt.dominates(b1, b2); } #[test] #[should_panic] fn unreachable_node_panic_b1_b3() { - let (dt, _b0, b1, _b2, b3) = unreachable_node_setup(); + let (mut dt, _b0, b1, _b2, b3) = unreachable_node_setup(); dt.dominates(b1, b3); } #[test] #[should_panic] fn 
unreachable_node_panic_b3_b1() { - let (dt, _b0, b1, b2, _b3) = unreachable_node_setup(); + let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup(); dt.dominates(b2, b1); } @@ -390,7 +407,7 @@ mod tests { let func = ssa.main(); let block0_id = func.entry_block(); - let dt = DominatorTree::with_function(func); + let mut dt = DominatorTree::with_function(func); // Expected dominance tree: // block0 { diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs index c0e45f9c1e8..bed0686e45b 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs @@ -1,9 +1,28 @@ //! This is an algorithm for identifying branch starts and ends. -use std::collections::{HashMap, HashSet}; +//! +//! The algorithm is split into two parts: +//! 1. The outer part: +//! A. An (unrolled) CFG can be thought of as a linear sequence of blocks where some nodes split +//! off, but eventually rejoin to a new node and continue the linear sequence. +//! B. Follow this sequence in order, and whenever a split is found call +//! `find_join_point_of_branches` and then recur from the join point it returns until the +//! return instruction is found. +//! +//! 2. The inner part defined by `find_join_point_of_branches`: +//! A. For each of the two branches in a jmpif block: +//! - Check if either has multiple predecessors. If so, it is a join point. +//! - If not, continue to search the linear sequence of successor blocks from that block. +//! - If another split point is found, recur in `find_join_point_of_branches` +//! - If a block with multiple predecessors is found, return it. +//! - After, we should have identified a join point for both branches. This is expected to be +//! the same block for both and can be returned from here to continue iteration. +//! +//! 
This algorithm will remember each join point found in `find_join_point_of_branches` and +//! the resulting map from each split block to each join block is returned. +use std::collections::HashMap; use crate::ssa_refactor::ir::{ - basic_block::BasicBlockId, cfg::ControlFlowGraph, dom::DominatorTree, function::Function, - post_order::PostOrder, + basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function, }; /// Returns a `HashMap` mapping blocks that start a branch (i.e. blocks terminated with jmpif) to @@ -16,121 +35,78 @@ pub(super) fn find_branch_ends( function: &Function, cfg: &ControlFlowGraph, ) -> HashMap { - let post_order = PostOrder::with_function(function); - let dom_tree = DominatorTree::with_cfg_and_post_order(cfg, &post_order); - let mut stepper = Stepper::new(function.entry_block()); - // This outer `visited` set is inconsequential, and simply here to satisfy the recursive - // stepper interface. - let mut visited = HashSet::new(); - let mut branch_ends = HashMap::new(); - while !stepper.finished { - stepper.step(cfg, &dom_tree, &mut visited, &mut branch_ends); - } - branch_ends -} + let mut block = function.entry_block(); + let mut context = Context::new(cfg); -/// Returns the block at which `left` and `right` converge, at the same time identifying branch -/// ends in any sub branches. -/// -/// This function is called by `Stepper::step` and is thus recursive. 
-fn step_until_rejoin( - cfg: &ControlFlowGraph, - dom_tree: &DominatorTree, - branch_ends: &mut HashMap, - left: BasicBlockId, - right: BasicBlockId, -) -> BasicBlockId { - let mut visited = HashSet::new(); - let mut left_stepper = Stepper::new(left); - let mut right_stepper = Stepper::new(right); + loop { + let mut successors = cfg.successors(block); - while !left_stepper.finished || !right_stepper.finished { - left_stepper.step(cfg, dom_tree, &mut visited, branch_ends); - right_stepper.step(cfg, dom_tree, &mut visited, branch_ends); + if successors.len() == 2 { + block = context.find_join_point_of_branches(block, successors); + } else if successors.len() == 1 { + block = successors.next().unwrap(); + } else if successors.len() == 0 { + // return encountered. We have nothing to join, so we're done + break; + } else { + unreachable!("A block can only have 0, 1, or 2 successors"); + } } - let collision = match (left_stepper.collision, right_stepper.collision) { - (Some(collision), None) | (None, Some(collision)) => collision, - (Some(_),Some(_))=> unreachable!("A collision on both branches indicates a loop"), - _ => unreachable!( - "Until we support multiple returns, branches always re-converge. Once supported this case should return `None`" - ), - }; - collision + + context.branch_ends } -/// Tracks traversal along the arm of a branch. Steppers are progressed in pairs, such that the -/// re-convergence point of two arms is discovered as soon as possible. The exceptional case is -/// that of the top level stepper, which conveniently steps the whole CFG as if it were a single -/// arm. -struct Stepper { - /// The block that will be interrogated when calling `step` - current_block: BasicBlockId, - /// Indicates that the stepper has no more block successors to process, either because it has - /// reached the end of the CFG, or because it encountered a block already visited by its - /// sibling stepper. 
- finished: bool, - /// Once finished this option indicates whether a collision was encountered before reaching - /// the end of the CFG. - collision: Option, +struct Context<'cfg> { + branch_ends: HashMap, + cfg: &'cfg ControlFlowGraph, } -impl Stepper { - /// Creates a fresh stepper instance - fn new(current_block: BasicBlockId) -> Self { - Stepper { current_block, finished: false, collision: None } +impl<'cfg> Context<'cfg> { + fn new(cfg: &'cfg ControlFlowGraph) -> Self { + Self { cfg, branch_ends: HashMap::new() } } - /// Checks the current block to see if it has already been visited and if so marks it as a - /// collision. If a sub-branch is encountered `step_until_rejoin` is called to start a pair - /// of child steppers stepping along its arms. - /// - /// It is safe to call this even when the stepper has reached its end. - fn step( + fn find_join_point_of_branches( &mut self, - cfg: &ControlFlowGraph, - dom_tree: &DominatorTree, - visited: &mut HashSet, - branch_ends: &mut HashMap, - ) { - if self.finished { - // The caller still needs to progress the other stepper, while this one sits idle. - return; - } - if visited.contains(&self.current_block) { - // The other stepper has already visited this block - thus this block is the - // re.-convergence point. 
- self.collision = Some(self.current_block); - self.finished = true; + start: BasicBlockId, + mut successors: impl Iterator, + ) -> BasicBlockId { + let left = successors.next().unwrap(); + let right = successors.next().unwrap(); + + let left_join = self.find_join_point(left); + let right_join = self.find_join_point(right); + + assert_eq!(left_join, right_join, "Expected two blocks to join to the same block"); + self.branch_ends.insert(start, left_join); + + left_join + } + + fn find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId { + let predecessors = self.cfg.predecessors(block); + if predecessors.len() > 1 { + return block; } - visited.insert(self.current_block); + // The join point is not this block, so continue on + self.skip_then_find_join_point(block) + } - let mut successors = cfg.successors(self.current_block); - match successors.len() { - 0 => { - // Reached the end of the CFG without a collision - this will happen in the other - // stepper assuming the CFG contains no early returns. - self.finished = true; - } - 1 => { - // This block doesn't describe any branch starts or ends - move on. 
- self.current_block = successors.next().unwrap(); - } - 2 => { - // Sub-branch start encountered - recurse to find the end of the sub branch - let left = successors.next().unwrap(); - let right = successors.next().unwrap(); - let sub_branch_end = step_until_rejoin(cfg, dom_tree, branch_ends, left, right); - for collision_predecessor in cfg.predecessors(sub_branch_end) { - assert!(dom_tree.dominates(self.current_block, collision_predecessor)); - } - branch_ends.insert(self.current_block, sub_branch_end); + fn skip_then_find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId { + let mut successors = self.cfg.successors(block); - // Resume stepping though the current arm fro where the sub-branch left off - self.current_block = sub_branch_end; - } - _ => { - unreachable!("Basic blocks never have more than 2 successors") - } + if successors.len() == 2 { + let join = self.find_join_point_of_branches(block, successors); + // Note that we call skip_then_find_join_point here instead of find_join_point. + // We already know this `join` is a join point, but it cannot be for the current block + // since we already know it is the join point of the successors of the current block. + self.skip_then_find_join_point(join) + } else if successors.len() == 1 { + self.find_join_point(successors.next().unwrap()) + } else if successors.len() == 0 { + unreachable!("return encountered before a join point was found. 
This can only happen if early-return was added to the language without implementing it by jmping to a join block first") + } else { + unreachable!("A block can only have 0, 1, or 2 successors"); } } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs index 96120999206..e5d7d6f0d5c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs @@ -63,7 +63,7 @@ struct Loops { fn find_all_loops(function: &Function) -> Loops { let cfg = ControlFlowGraph::with_function(function); let post_order = PostOrder::with_function(function); - let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); + let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order); let mut loops = vec![];