Skip to content

Commit

Permalink
conditional constant propagation (#6498)
Browse files Browse the repository at this point in the history
Given this Sway code

```
script;

pub fn main() -> u64 {
    foo(10) + foo(11)
}

fn foo(p: u64) -> u64 {
   if p == 10 {
      p + 1
   } else {
      p - 1
   }
}
```

This optimization (after const-folding) produces

```
fn foo_2(p !116: u64) -> u64, !119 {
        entry(p: u64):
        v0 = const u64 10, !120
        v1 = cmp eq p v0, !123
        v2 = const u64 11, !126
        cbr v1, block2(v2), block1(), !121

        block1():
        v3 = const u64 1, !127
        v4 = sub p, v3, !130
        br block2(v4)

        block2(v5: u64):
        ret u64 v5
    }
```

The `p+1` has been optimized into `11`.

Co-authored-by: Joshua Batty <joshpbatty@gmail.com>
  • Loading branch information
vaivaswatha and JoshuaBatty authored Sep 9, 2024
1 parent 31a1d6f commit 8232d42
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 17 deletions.
2 changes: 2 additions & 0 deletions sway-ir/src/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ pub mod const_demotion;
pub use const_demotion::*;
pub mod constants;
pub use constants::*;
pub mod conditional_constprop;
pub use conditional_constprop::*;
pub mod dce;
pub use dce::*;
pub mod inline;
Expand Down
104 changes: 104 additions & 0 deletions sway-ir/src/optimize/conditional_constprop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//! When a value is guaranteed to have a constant value in a region of the CFG,
//! this optimization replaces uses of that value with the constant in that region.
use rustc_hash::FxHashMap;

use crate::{
AnalysisResults, Context, DomTree, Function, InstOp, Instruction, IrError, Pass,
PassMutability, Predicate, ScopedPass, DOMINATORS_NAME,
};

pub const CCP_NAME: &str = "ccp";

pub fn create_ccp_pass() -> Pass {
Pass {
name: CCP_NAME,
descr: "Conditional constant proparagion",
deps: vec![DOMINATORS_NAME],
runner: ScopedPass::FunctionPass(PassMutability::Transform(ccp)),
}
}

pub fn ccp(
context: &mut Context,
analyses: &AnalysisResults,
function: Function,
) -> Result<bool, IrError> {
let dom_tree: &DomTree = analyses.get_analysis_result(function);

// In the set of blocks dominated by `key`, replace all uses of `val.0` with `val.1`.
let mut dom_region_replacements = FxHashMap::default();

for block in function.block_iter(context) {
let term = block
.get_terminator(context)
.expect("Malformed block: no terminator");
if let InstOp::ConditionalBranch {
cond_value,
true_block,
false_block: _,
} = &term.op
{
if let Some(Instruction {
parent: _,
op: InstOp::Cmp(pred, v1, v2),
}) = cond_value.get_instruction(context)
{
if matches!(pred, Predicate::Equal)
&& (v1.is_constant(context) ^ v2.is_constant(context)
&& true_block.block.num_predecessors(context) == 1)
{
if v1.is_constant(context) {
dom_region_replacements.insert(true_block.block, (*v2, *v1));
} else {
dom_region_replacements.insert(true_block.block, (*v1, *v2));
}
}
}
}
}

// lets walk the dominator tree from the root.
let Some((root_block, _root_node)) =
dom_tree.iter().find(|(_block, node)| node.parent.is_none())
else {
panic!("Dominator tree without root");
};

if dom_region_replacements.is_empty() {
return Ok(false);
}

let mut stack = vec![(*root_block, 0)];
let mut replacements = FxHashMap::default();
while let Some((block, next_child)) = stack.last().cloned() {
let cur_replacement_opt = dom_region_replacements.get(&block);

if next_child == 0 {
// Preorder processing
if let Some(cur_replacement) = cur_replacement_opt {
replacements.insert(cur_replacement.0, cur_replacement.1);
}
// walk the current block.
block.replace_values(context, &replacements);
}

let block_node = &dom_tree[&block];

// walk children.
if let Some(child) = block_node.children.get(next_child) {
// When we arrive back at "block" next time, we should process the next child.
stack.last_mut().unwrap().1 = next_child + 1;
// Go on to process the child.
stack.push((*child, 0));
} else {
// No children left to process. Start postorder processing.
if let Some(cur_replacement) = cur_replacement_opt {
replacements.remove(&cur_replacement.0);
}
stack.pop();
}
}

Ok(true)
}
22 changes: 12 additions & 10 deletions sway-ir/src/pass_manager.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::{
create_arg_demotion_pass, create_const_demotion_pass, create_const_folding_pass,
create_dce_pass, create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
create_fn_dce_pass, create_fn_dedup_debug_profile_pass, create_fn_dedup_release_profile_pass,
create_fn_inline_pass, create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass,
create_module_printer_pass, create_module_verifier_pass, create_postorder_pass,
create_ret_demotion_pass, create_simplify_cfg_pass, create_sroa_pass, Context, Function,
IrError, Module, ARG_DEMOTION_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, DCE_NAME,
FN_DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME,
MEM2REG_NAME, MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME,
SROA_NAME,
create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass,
create_const_folding_pass, create_dce_pass, create_dom_fronts_pass, create_dominators_pass,
create_escaped_symbols_pass, create_fn_dce_pass, create_fn_dedup_debug_profile_pass,
create_fn_dedup_release_profile_pass, create_fn_inline_pass, create_mem2reg_pass,
create_memcpyopt_pass, create_misc_demotion_pass, create_module_printer_pass,
create_module_verifier_pass, create_postorder_pass, create_ret_demotion_pass,
create_simplify_cfg_pass, create_sroa_pass, Context, Function, IrError, Module,
ARG_DEMOTION_NAME, CCP_NAME, CONST_DEMOTION_NAME, CONST_FOLDING_NAME, DCE_NAME, FN_DCE_NAME,
FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, FN_INLINE_NAME, MEM2REG_NAME,
MEMCPYOPT_NAME, MISC_DEMOTION_NAME, RET_DEMOTION_NAME, SIMPLIFY_CFG_NAME, SROA_NAME,
};
use downcast_rs::{impl_downcast, Downcast};
use rustc_hash::FxHashMap;
Expand Down Expand Up @@ -396,6 +396,7 @@ pub fn register_known_passes(pm: &mut PassManager) {
pm.register(create_sroa_pass());
pm.register(create_fn_inline_pass());
pm.register(create_const_folding_pass());
pm.register(create_ccp_pass());
pm.register(create_simplify_cfg_pass());
pm.register(create_fn_dce_pass());
pm.register(create_dce_pass());
Expand All @@ -416,6 +417,7 @@ pub fn create_o1_pass_group() -> PassGroup {
o1.append_pass(SIMPLIFY_CFG_NAME);
o1.append_pass(FN_DCE_NAME);
o1.append_pass(FN_INLINE_NAME);
o1.append_pass(CCP_NAME);
o1.append_pass(CONST_FOLDING_NAME);
o1.append_pass(SIMPLIFY_CFG_NAME);
o1.append_pass(CONST_FOLDING_NAME);
Expand Down
29 changes: 29 additions & 0 deletions sway-ir/tests/ccp/ccp1.ir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// regex: ID=[[:alpha:]_0-9]+

script {
// check: fn $ID($(param=$ID): u64) -> u64
fn test1(p: u64) -> u64 {
entry(p: u64):
// check: $(c=$ID) = const u64 100
v0 = const u64 100
// check: cmp eq $c $param
v2 = cmp eq v0 p
// check: $ID, $(true_block=$ID)(), $(false_block=$ID)()
cbr v2, get_0_block0(), get_0_block1()

// check: $true_block():
get_0_block0():
// check: $(one=$ID) = const u64 1
v5 = const u64 1
// check: add $one, $c
v6 = add v5, p
ret u64 v6

// check: $false_block():
get_0_block1():
v7 = const u64 111
// check: add $ID, $param
v8 = add v7, p
ret u64 v8
}
}
30 changes: 23 additions & 7 deletions sway-ir/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::path::PathBuf;

use itertools::Itertools;
use sway_ir::{
create_arg_demotion_pass, create_const_demotion_pass, create_const_folding_pass,
create_dce_pass, create_dom_fronts_pass, create_dominators_pass, create_escaped_symbols_pass,
create_mem2reg_pass, create_memcpyopt_pass, create_misc_demotion_pass, create_postorder_pass,
create_ret_demotion_pass, create_simplify_cfg_pass, metadata_to_inline, optimize as opt,
register_known_passes, Context, ExperimentalFlags, Function, IrError, PassGroup, PassManager,
Value, DCE_NAME, FN_DCE_NAME, FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME,
MEM2REG_NAME, SROA_NAME,
create_arg_demotion_pass, create_ccp_pass, create_const_demotion_pass,
create_const_folding_pass, create_dce_pass, create_dom_fronts_pass, create_dominators_pass,
create_escaped_symbols_pass, create_mem2reg_pass, create_memcpyopt_pass,
create_misc_demotion_pass, create_postorder_pass, create_ret_demotion_pass,
create_simplify_cfg_pass, metadata_to_inline, optimize as opt, register_known_passes, Context,
ExperimentalFlags, Function, IrError, PassGroup, PassManager, Value, DCE_NAME, FN_DCE_NAME,
FN_DEDUP_DEBUG_PROFILE_NAME, FN_DEDUP_RELEASE_PROFILE_NAME, MEM2REG_NAME, SROA_NAME,
};
use sway_types::SourceEngine;

Expand Down Expand Up @@ -237,6 +237,22 @@ fn constants() {

// -------------------------------------------------------------------------------------------------

#[allow(clippy::needless_collect)]
#[test]
fn ccp() {
run_tests("ccp", |_first_line, ir: &mut Context| {
let mut pass_mgr = PassManager::default();
let mut pass_group = PassGroup::default();
pass_mgr.register(create_postorder_pass());
pass_mgr.register(create_dominators_pass());
let pass = pass_mgr.register(create_ccp_pass());
pass_group.append_pass(pass);
pass_mgr.run(ir, &pass_group).unwrap()
})
}

// -------------------------------------------------------------------------------------------------

#[allow(clippy::needless_collect)]
#[test]
fn simplify_cfg() {
Expand Down

0 comments on commit 8232d42

Please sign in to comment.