From b3f34f148cbe05e6cd74ef36144275bb03d015f7 Mon Sep 17 00:00:00 2001
From: Jake Fecher <jake@aztecprotocol.com>
Date: Wed, 21 Jun 2023 15:48:55 -0500
Subject: [PATCH 1/2] Speedup find-branch-ends

---
 .../src/ssa_refactor/ir/dom.rs                |  39 +++-
 .../opt/flatten_cfg/branch_analysis.rs        | 194 ++++++++----------
 .../src/ssa_refactor/opt/unrolling.rs         |   2 +-
 3 files changed, 118 insertions(+), 117 deletions(-)

diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs
index 93c3567fa9a..4763ffffbd1 100644
--- a/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs
+++ b/crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs
@@ -44,6 +44,9 @@ pub(crate) struct DominatorTree {
     /// After dominator tree computation has complete, this will contain a node for every
     /// reachable block, and no nodes for unreachable blocks.
     nodes: HashMap<BasicBlockId, DominatorTreeNode>,
+
+    /// Subsequent calls to `dominates` are cached to speed up access
+    cache: HashMap<(BasicBlockId, BasicBlockId), bool>,
 }
 
 /// Methods for querying the dominator tree.
@@ -83,7 +86,21 @@ impl DominatorTree {
     /// This function panics if either of the blocks are unreachable.
     ///
     /// An instruction is considered to dominate itself.
-    pub(crate) fn dominates(&self, block_a_id: BasicBlockId, mut block_b_id: BasicBlockId) -> bool {
+    pub(crate) fn dominates(&mut self, block_a_id: BasicBlockId, block_b_id: BasicBlockId) -> bool {
+        if let Some(res) = self.cache.get(&(block_a_id, block_b_id)) {
+            return *res;
+        }
+
+        let result = self.dominates_helper(block_a_id, block_b_id);
+        self.cache.insert((block_a_id, block_b_id), result);
+        result
+    }
+
+    pub(crate) fn dominates_helper(
+        &self,
+        block_a_id: BasicBlockId,
+        mut block_b_id: BasicBlockId,
+    ) -> bool {
         // Walk up the dominator tree from "b" until we encounter or pass "a". Doing the
         // comparison on the reverse post-order may allows to test whether we have passed "a"
         // without waiting until we reach the root of the tree.
@@ -104,7 +121,7 @@ impl DominatorTree {
     /// Allocate and compute a dominator tree from a pre-computed control flow graph and
     /// post-order counterpart.
     pub(crate) fn with_cfg_and_post_order(cfg: &ControlFlowGraph, post_order: &PostOrder) -> Self {
-        let mut dom_tree = DominatorTree { nodes: HashMap::new() };
+        let mut dom_tree = DominatorTree { nodes: HashMap::new(), cache: HashMap::new() };
         dom_tree.compute_dominator_tree(cfg, post_order);
         dom_tree
     }
@@ -249,7 +266,7 @@ mod tests {
             block0_id,
             TerminatorInstruction::Return { return_values: vec![] },
         );
-        let dom_tree = DominatorTree::with_function(&func);
+        let mut dom_tree = DominatorTree::with_function(&func);
         assert!(dom_tree.dominates(block0_id, block0_id));
     }
 
@@ -308,7 +325,7 @@ mod tests {
     // unreachable, performing this query indicates an internal compiler error.
     #[test]
     fn unreachable_node_asserts() {
-        let (dt, b0, _b1, b2, b3) = unreachable_node_setup();
+        let (mut dt, b0, _b1, b2, b3) = unreachable_node_setup();
 
         assert!(dt.dominates(b0, b0));
         assert!(dt.dominates(b0, b2));
@@ -326,42 +343,42 @@ mod tests {
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b0_b1() {
-        let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
+        let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
         dt.dominates(b0, b1);
     }
 
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b1_b0() {
-        let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
+        let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
         dt.dominates(b1, b0);
     }
 
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b1_b1() {
-        let (dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
+        let (mut dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
         dt.dominates(b1, b1);
     }
 
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b1_b2() {
-        let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
+        let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
         dt.dominates(b1, b2);
     }
 
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b1_b3() {
-        let (dt, _b0, b1, _b2, b3) = unreachable_node_setup();
+        let (mut dt, _b0, b1, _b2, b3) = unreachable_node_setup();
         dt.dominates(b1, b3);
     }
 
     #[test]
     #[should_panic]
     fn unreachable_node_panic_b3_b1() {
-        let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
+        let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
         dt.dominates(b2, b1);
     }
 
@@ -390,7 +407,7 @@ mod tests {
         let func = ssa.main();
         let block0_id = func.entry_block();
 
-        let dt = DominatorTree::with_function(func);
+        let mut dt = DominatorTree::with_function(func);
 
         // Expected dominance tree:
         // block0 {
diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
index c0e45f9c1e8..aa65c75a484 100644
--- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
+++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
@@ -1,9 +1,28 @@
 //! This is an algorithm for identifying branch starts and ends.
-use std::collections::{HashMap, HashSet};
+//!
+//! The algorithm is split into two parts:
+//! 1. The outer part:
+//!   A. An (unrolled) CFG can be though of as a linear sequence of blocks where some nodes split
+//!   off, but eventually rejoin to a new node and continue the linear sequence.
+//!   B. Follow this sequence in order, and whenever a split is found call
+//!   `find_join_point_of_branches` and then recur from the join point it returns until the
+//!   return instruction is found.
+//!
+//! 2. The inner part defined by `find_join_point_of_branches`:
+//!   A. For each of the two branches in a jmpif block:
+//!     - Check if either has multiple predecessors. If so, it is a join point.
+//!     - If not, continue to search the linear sequence of successor blocks from that block.
+//!       - If another split point is found, recur in `find_join_point_of_branches`
+//!       - If a block with multiple predecessors is found, return it.
+//!     - After, we should have identified a join point for both branches. This is expected to be
+//!       the same block for both and can be returned from here to continue iteration.
+//!
+//! This algorithm will remember each join point found in `find_join_point_of_branches` and
+//! the resulting map from each split block to each join block is returned.
+use std::collections::HashMap;
 
 use crate::ssa_refactor::ir::{
-    basic_block::BasicBlockId, cfg::ControlFlowGraph, dom::DominatorTree, function::Function,
-    post_order::PostOrder,
+    basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function,
 };
 
 /// Returns a `HashMap` mapping blocks that start a branch (i.e. blocks terminated with jmpif) to
@@ -16,121 +35,86 @@ pub(super) fn find_branch_ends(
     function: &Function,
     cfg: &ControlFlowGraph,
 ) -> HashMap<BasicBlockId, BasicBlockId> {
-    let post_order = PostOrder::with_function(function);
-    let dom_tree = DominatorTree::with_cfg_and_post_order(cfg, &post_order);
-    let mut stepper = Stepper::new(function.entry_block());
-    // This outer `visited` set is inconsequential, and simply here to satisfy the recursive
-    // stepper interface.
-    let mut visited = HashSet::new();
-    let mut branch_ends = HashMap::new();
-    while !stepper.finished {
-        stepper.step(cfg, &dom_tree, &mut visited, &mut branch_ends);
-    }
-    branch_ends
-}
+    let mut block = function.entry_block();
+    let mut context = Context::new(cfg);
 
-/// Returns the block at which `left` and `right` converge, at the same time identifying branch
-/// ends in any sub branches.
-///
-/// This function is called by `Stepper::step` and is thus recursive.
-fn step_until_rejoin(
-    cfg: &ControlFlowGraph,
-    dom_tree: &DominatorTree,
-    branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
-    left: BasicBlockId,
-    right: BasicBlockId,
-) -> BasicBlockId {
-    let mut visited = HashSet::new();
-    let mut left_stepper = Stepper::new(left);
-    let mut right_stepper = Stepper::new(right);
+    println!("Finding branch ends");
+
+    loop {
+        let mut successors = cfg.successors(block);
 
-    while !left_stepper.finished || !right_stepper.finished {
-        left_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
-        right_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
+        println!("On top-level block {}", block);
+
+        if successors.len() == 2 {
+            block = context.find_join_point_of_branches(block, successors);
+        } else if successors.len() == 1 {
+            block = successors.next().unwrap();
+        } else if successors.len() == 0 {
+            // return encountered. We have nothing to join, so we're done
+            break;
+        } else {
+            unreachable!("A block can only have 0, 1, or 2 successors");
+        }
     }
-    let collision = match (left_stepper.collision, right_stepper.collision) {
-        (Some(collision), None) | (None, Some(collision)) => collision,
-        (Some(_),Some(_))=> unreachable!("A collision on both branches indicates a loop"), 
-        _ => unreachable!(
-            "Until we support multiple returns, branches always re-converge. Once supported this case should return `None`"
-        ),
-    };
-    collision
+
+    println!("Done finding branch ends");
+
+    context.branch_ends
 }
 
-/// Tracks traversal along the arm of a branch. Steppers are progressed in pairs, such that the
-/// re-convergence point of two arms is discovered as soon as possible. The exceptional case is
-/// that of the top level stepper, which conveniently steps the whole CFG as if it were a single
-/// arm.
-struct Stepper {
-    /// The block that will be interrogated when calling `step`
-    current_block: BasicBlockId,
-    /// Indicates that the stepper has no more block successors to process, either because it has
-    /// reached the end of the CFG, or because it encountered a block already visited by its
-    /// sibling stepper.
-    finished: bool,
-    /// Once finished this option indicates whether a collision was encountered before reaching
-    /// the end of the CFG.
-    collision: Option<BasicBlockId>,
+struct Context<'cfg> {
+    branch_ends: HashMap<BasicBlockId, BasicBlockId>,
+    cfg: &'cfg ControlFlowGraph,
 }
 
-impl Stepper {
-    /// Creates a fresh stepper instance
-    fn new(current_block: BasicBlockId) -> Self {
-        Stepper { current_block, finished: false, collision: None }
+impl<'cfg> Context<'cfg> {
+    fn new(cfg: &'cfg ControlFlowGraph) -> Self {
+        Self { cfg, branch_ends: HashMap::new() }
     }
 
-    /// Checks the current block to see if it has already been visited and if so marks it as a
-    /// collision. If a sub-branch is encountered `step_until_rejoin` is called to start a pair
-    /// of child steppers stepping along its arms.
-    ///
-    /// It is safe to call this even when the stepper has reached its end.
-    fn step(
+    fn find_join_point_of_branches(
         &mut self,
-        cfg: &ControlFlowGraph,
-        dom_tree: &DominatorTree,
-        visited: &mut HashSet<BasicBlockId>,
-        branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
-    ) {
-        if self.finished {
-            // The caller still needs to progress the other stepper, while this one sits idle.
-            return;
-        }
-        if visited.contains(&self.current_block) {
-            // The other stepper has already visited this block - thus this block is the
-            // re.-convergence point.
-            self.collision = Some(self.current_block);
-            self.finished = true;
+        start: BasicBlockId,
+        mut successors: impl Iterator<Item = BasicBlockId>,
+    ) -> BasicBlockId {
+        let left = successors.next().unwrap();
+        let right = successors.next().unwrap();
+
+        println!("On jmpif block {}", start);
+
+        let left_join = self.find_join_point(left);
+        let right_join = self.find_join_point(right);
+
+        assert_eq!(left_join, right_join, "Expected two blocks to join to the same block");
+        self.branch_ends.insert(start, left_join);
+
+        left_join
+    }
+
+    fn find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
+        let predecessors = self.cfg.predecessors(block);
+        if predecessors.len() > 1 {
+            return block;
         }
-        visited.insert(self.current_block);
+        // The join point is not this block, so continue on
+        self.skip_then_find_join_point(block)
+    }
 
-        let mut successors = cfg.successors(self.current_block);
-        match successors.len() {
-            0 => {
-                // Reached the end of the CFG without a collision - this will happen in the other
-                // stepper assuming the CFG contains no early returns.
-                self.finished = true;
-            }
-            1 => {
-                // This block doesn't describe any branch starts or ends - move on.
-                self.current_block = successors.next().unwrap();
-            }
-            2 => {
-                // Sub-branch start encountered - recurse to find the end of the sub branch
-                let left = successors.next().unwrap();
-                let right = successors.next().unwrap();
-                let sub_branch_end = step_until_rejoin(cfg, dom_tree, branch_ends, left, right);
-                for collision_predecessor in cfg.predecessors(sub_branch_end) {
-                    assert!(dom_tree.dominates(self.current_block, collision_predecessor));
-                }
-                branch_ends.insert(self.current_block, sub_branch_end);
+    fn skip_then_find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
+        let mut successors = self.cfg.successors(block);
 
-                // Resume stepping though the current arm fro where the sub-branch left off
-                self.current_block = sub_branch_end;
-            }
-            _ => {
-                unreachable!("Basic blocks never have more than 2 successors")
-            }
+        if successors.len() == 2 {
+            let join = self.find_join_point_of_branches(block, successors);
+            // Note that we call skip_then_find_join_point here instead of find_join_point.
+            // We already know this `join` is a join point, but it cannot be for the current block
+            // since we already know it is the join point of the successors of the current block.
+            self.skip_then_find_join_point(join)
+        } else if successors.len() == 1 {
+            self.find_join_point(successors.next().unwrap())
+        } else if successors.len() == 0 {
+            unreachable!("return encountered before a join point was found. This can only happen if early-return was added to the language without implementing it by jmping to a join block first")
+        } else {
+            unreachable!("A block can only have 0, 1, or 2 successors");
         }
     }
 }
diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs
index 96120999206..e5d7d6f0d5c 100644
--- a/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs
+++ b/crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs
@@ -63,7 +63,7 @@ struct Loops {
 fn find_all_loops(function: &Function) -> Loops {
     let cfg = ControlFlowGraph::with_function(function);
     let post_order = PostOrder::with_function(function);
-    let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
+    let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
 
     let mut loops = vec![];
 

From 6e860ced6661ac10c3917f97ce851c24c2a766f6 Mon Sep 17 00:00:00 2001
From: Jake Fecher <jake@aztecprotocol.com>
Date: Wed, 21 Jun 2023 15:55:57 -0500
Subject: [PATCH 2/2] Remove timing printlns

---
 .../src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs   | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
index aa65c75a484..bed0686e45b 100644
--- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
+++ b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg/branch_analysis.rs
@@ -38,13 +38,9 @@ pub(super) fn find_branch_ends(
     let mut block = function.entry_block();
     let mut context = Context::new(cfg);
 
-    println!("Finding branch ends");
-
     loop {
         let mut successors = cfg.successors(block);
 
-        println!("On top-level block {}", block);
-
         if successors.len() == 2 {
             block = context.find_join_point_of_branches(block, successors);
         } else if successors.len() == 1 {
@@ -57,8 +53,6 @@ pub(super) fn find_branch_ends(
         }
     }
 
-    println!("Done finding branch ends");
-
     context.branch_ends
 }
 
@@ -80,8 +74,6 @@ impl<'cfg> Context<'cfg> {
         let left = successors.next().unwrap();
         let right = successors.next().unwrap();
 
-        println!("On jmpif block {}", start);
-
         let left_join = self.find_join_point(left);
         let right_join = self.find_join_point(right);