Simplify switches on a statically known discriminant in MIR. #112688

Closed
3 changes: 3 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
@@ -101,6 +101,7 @@ mod check_alignment;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
mod simplify_static_switch;
mod sroa;
mod uninhabited_enum_branching;
mod unreachable_prop;
@@ -561,6 +562,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&simplify::SimplifyLocals::BeforeConstProp,
&copy_prop::CopyProp,
&ref_prop::ReferencePropagation,
// Remove switches on a statically known discriminant, which can happen as a result of inlining.
&simplify_static_switch::SimplifyStaticSwitch,
// Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may
// destroy the SSA property. It should still happen before const-propagation, so the
// latter pass will leverage the created opportunities.
320 changes: 320 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_static_switch.rs
@@ -0,0 +1,320 @@
use super::MirPass;

use rustc_data_structures::fx::FxHashMap;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
use rustc_middle::mir::{
AggregateKind, BasicBlock, Body, Local, Location, Operand, Place, Rvalue, StatementKind,
TerminatorKind,
};
use rustc_middle::ty::TyCtxt;
use rustc_mir_dataflow::impls::MaybeBorrowedLocals;
use rustc_mir_dataflow::Analysis;
use rustc_session::Session;

use super::simplify;
use super::ssa::SsaLocals;

/// # Overview
///
/// This pass looks to optimize a pattern in MIR where variants of an aggregate
/// are constructed in one or more blocks with the same successor and then that
/// aggregate/discriminant is switched on in that successor block, in which case
/// we can remove the switch on the discriminant because we statically know
/// what target block will be taken for each variant.
///
/// Note that an aggregate which is returned from a function call or passed as
/// an argument is not viable for this optimization because we do not statically
/// know the discriminant/variant of the aggregate.
///
/// For example, the following CFG:
/// ```text
///     x = Foo::A(y); ---              Foo::A ---> ...
///    /                  \            /
/// ...                    --> switch x
///    \                  /            \
///     x = Foo::B(y); ---              Foo::B ---> ...
/// ```
/// would become:
/// ```text
///     x = Foo::A(y); --------- Foo::A ---> ...
///    /
/// ...
///    \
///     x = Foo::B(y); --------- Foo::B ---> ...
/// ```
///
/// # Soundness
///
/// - If the discriminant being switched on is not SSA, or if the aggregate is
/// mutated before the discriminant is assigned, the optimization cannot be
/// applied because we no longer statically know what variant the aggregate
/// could be, or what discriminant is being switched on.
///
/// - If the discriminant is borrowed before being switched on, or the aggregate
/// is borrowed before the discriminant is assigned, we also cannot optimize due
/// to the possibility stated in the first point.
///
/// - An aggregate being constructed has a known variant, and if it is not borrowed
/// or mutated before being switched on, then it does not actually need a runtime
/// switch on the discriminant (aka variant) of said aggregate.
///
pub struct SimplifyStaticSwitch;

impl<'tcx> MirPass<'tcx> for SimplifyStaticSwitch {
fn is_enabled(&self, sess: &Session) -> bool {
sess.mir_opt_level() >= 2
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!("Running SimplifyStaticSwitch on {:?}", body.source.def_id());

let ssa_locals = SsaLocals::new(body);
if simplify_static_switches(tcx, body, &ssa_locals) {
simplify::remove_dead_blocks(tcx, body);
}
}
}

#[instrument(level = "debug", skip_all, ret)]
fn simplify_static_switches<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
ssa_locals: &SsaLocals,
) -> bool {
let dominators = body.basic_blocks.dominators();
let predecessors = body.basic_blocks.predecessors();
let mut discriminants = FxHashMap::default();
let mut static_switches = FxHashMap::default();
Comment on lines +86 to +87
Contributor comment:

Could you document what is stored in each map?

let mut borrowed_locals =
MaybeBorrowedLocals.into_engine(tcx, body).iterate_to_fixpoint().into_results_cursor(body);
for (switched, rvalue, location) in ssa_locals.assignments(body) {
let Rvalue::Discriminant(discr) = rvalue else {
continue
};

borrowed_locals.seek_after_primary_effect(location);
// If `discr` was borrowed before its discriminant was assigned to `switched`,
// or if it was borrowed in the assignment, we cannot optimize.
if borrowed_locals.contains(discr.local) {
debug!("The aggregate: {discr:?} was borrowed before its discriminant was read");
continue;
}

let Location { block, statement_index } = location;
let mut finder = MutatedLocalFinder { local: discr.local, mutated: false };
for (statement_index, statement) in body.basic_blocks[block]
.statements
.iter()
.enumerate()
.take_while(|&(index, _)| index != statement_index)
Comment on lines +106 to +109
Contributor comment:

Suggested change
- .statements
- .iter()
- .enumerate()
- .take_while(|&(index, _)| index != statement_index)
+ .statements[..statement_index]
+ .iter()
+ .enumerate()
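
To check that the suggested slice form visits the same statements as the original `take_while` form, here is a minimal standalone sketch. The helper names and the use of `&str` in place of MIR statements are assumptions made for illustration. The slice form also assumes `stop` is at most the slice length, which should hold in the pass because `statement_index` is the index of an existing statement.

```rust
/// Original form: enumerate everything and stop once the index reaches `stop`.
fn prefix_with_take_while(statements: &[&str], stop: usize) -> Vec<(usize, String)> {
    statements
        .iter()
        .enumerate()
        .take_while(|&(index, _)| index != stop)
        .map(|(index, s)| (index, s.to_string()))
        .collect()
}

/// Suggested form: slice off the prefix first, then enumerate it.
/// Panics if `stop > statements.len()`, unlike the `take_while` form.
fn prefix_with_slice(statements: &[&str], stop: usize) -> Vec<(usize, String)> {
    statements[..stop]
        .iter()
        .enumerate()
        .map(|(index, s)| (index, s.to_string()))
        .collect()
}
```

Both helpers yield the same `(index, statement)` pairs for any in-bounds `stop`, so the slice form is a drop-in replacement as long as the index is known to be valid.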

{
finder.visit_statement(statement, Location { block, statement_index });
}

if finder.mutated {
debug!("The aggregate: {discr:?} was mutated before its discriminant was read");
continue;
}

// If `switched` is borrowed by the time we actually switch on it, we also cannot optimize.
borrowed_locals.seek_to_block_end(block);
if borrowed_locals.contains(switched) {
debug!("The local: {switched:?} was borrowed before being switched on");
continue;
}

discriminants.insert(
switched,
Discriminant {
block,
discr: discr.local,
exclude: if ssa_locals.num_direct_uses(switched) == 1 {
// If there is only one direct use of `switched` we do not need to keep
// it around because the only use is in the switch.
Some(statement_index)
} else {
None
},
},
);
}

if discriminants.is_empty() {
debug!("No SSA locals were assigned a discriminant");
return false;
}

for (switched, Discriminant { discr, block, exclude }) in discriminants {
Contributor comment:

`discriminants` is a hash map, so its iteration order is unstable. Should `discriminants` be a `Vec` of pairs instead?
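
One way to get a stable iteration order while keeping keyed updates is to pair a `Vec` of entries with an index map. The sketch below is a standalone illustration using only the standard library; the `OrderedMap` type and the `u32` keys are hypothetical stand-ins and not part of the PR. Inside rustc an insertion-ordered map such as `FxIndexMap` would probably be the more natural choice.

```rust
use std::collections::HashMap;

/// Hypothetical insertion-ordered map: the `Vec` fixes the iteration order,
/// the `HashMap` only remembers where each key lives in the `Vec`.
struct OrderedMap<V> {
    index: HashMap<u32, usize>,
    entries: Vec<(u32, V)>,
}

impl<V> OrderedMap<V> {
    fn new() -> Self {
        OrderedMap { index: HashMap::new(), entries: Vec::new() }
    }

    /// Inserts a new key at the end, or overwrites the value of an existing key in place.
    fn insert(&mut self, key: u32, value: V) {
        let existing = self.index.get(&key).copied();
        match existing {
            Some(slot) => self.entries[slot].1 = value,
            None => {
                self.index.insert(key, self.entries.len());
                self.entries.push((key, value));
            }
        }
    }

    /// Iterates in insertion order, independent of any hashing.
    fn iter(&self) -> std::slice::Iter<'_, (u32, V)> {
        self.entries.iter()
    }
}
```

Iterating `entries` always follows first-insertion order, so any transformation driven by it is reproducible across runs.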

let data = &body.basic_blocks[block];
if data.is_cleanup {
continue;
}

let predecessors = &predecessors[block];
if predecessors.is_empty() {
continue;
}

if predecessors.iter().any(|&pred| {
// If we find a backedge from: `pred -> block`, this indicates that
// `block` is a loop header. To avoid creating irreducible CFGs we do
// not thread through loop headers.
dominators.dominates(block, pred)
}) {
debug!("Unable to thread through loop header: {block:?}");
continue;
}

let terminator = data.terminator();
let TerminatorKind::SwitchInt {
discr: Operand::Copy(place) | Operand::Move(place),
targets
} = &terminator.kind else {
continue
};

if place.local != switched {
Contributor comment:

`local` necessarily has an integer type. Should we assert that `place.projection` is empty?

continue;
}

let mut finder = MutatedLocalFinder { local: discr, mutated: false };
Contributor comment:

This should be created inside the loop, to ensure we don't forget to reset mutated.

'preds: for &pred in predecessors {
let data = &body.basic_blocks[pred];
let terminator = data.terminator();
let TerminatorKind::Goto { .. } = terminator.kind else {
continue
};

for (statement_index, statement) in data.statements.iter().enumerate().rev() {
match statement.kind {
Contributor comment:

Nit: this can be an if-let.

StatementKind::SetDiscriminant { box place, variant_index: variant }
| StatementKind::Assign(box (
place,
Rvalue::Aggregate(box AggregateKind::Adt(_, variant, ..), ..),
)) if place.local == discr => {
Contributor comment:

Assert place.projection is empty here too.

if finder.mutated {
debug!(
"The discriminant: {discr:?} was mutated in predecessor: {pred:?}"
);
// We can't optimize this predecessor, so try the next one.
finder.mutated = false;

continue 'preds;
}

let discr_ty = body.local_decls[discr].ty;
if let Some(discr) = discr_ty.discriminant_for_variant(tcx, variant) {
debug!(
"{pred:?}: {place:?} = {discr_ty:?}::{variant:?}; goto -> {block:?}",
);

let target = targets.target_for_value(discr.val);
static_switches
.entry(block)
.and_modify(|static_switches: &mut &mut [StaticSwitch]| {
if static_switches.iter_mut().all(|switch| {
if switch.pred == pred {
switch.target = target;
false
} else {
true
}
}) {
*static_switches =
tcx.arena.alloc_from_iter(
static_switches.iter().copied().chain([
StaticSwitch { pred, target, exclude },
]),
);
}
})
.or_insert_with(|| {
tcx.arena.alloc([StaticSwitch { pred, target, exclude }])
});
Contributor comment:

I don't understand how the worklist is stored and why.
Could you add an explanation?
Why isn't a vector of edges (pred, block, target), meaning "the pred->block edge becomes pred->target" enough?
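
The flat worklist proposed in this comment could look roughly like the sketch below. It models a CFG as a slice of toy terminators; `Terminator`, `EdgeRewrite`, and `apply_rewrites` are hypothetical names for illustration, not rustc types. The sketch covers only the retargeting of the `pred -> block` edge and leaves out the copying of `block`'s statements into `pred` that the PR also performs.

```rust
/// Toy terminator; block indices are plain `usize` values.
#[derive(Debug, Clone, PartialEq)]
enum Terminator {
    Goto(usize),
    /// `(value, target)` pairs plus an `otherwise` target.
    Switch(Vec<(u128, usize)>, usize),
    Return,
}

/// "The `pred -> block` edge becomes `pred -> target`."
struct EdgeRewrite {
    pred: usize,
    block: usize,
    target: usize,
}

/// Applies every rewrite in order; a `Vec` keeps that order deterministic.
fn apply_rewrites(terminators: &mut [Terminator], worklist: &[EdgeRewrite]) {
    for edit in worklist {
        let terminator = &mut terminators[edit.pred];
        // The predecessor must still jump unconditionally to the switch block.
        assert_eq!(*terminator, Terminator::Goto(edit.block));
        *terminator = Terminator::Goto(edit.target);
    }
}
```

With this shape, each predecessor appears at most once if the analysis pushes one entry per `pred`, and no per-block grouping or arena allocation is needed.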

}

continue 'preds;
}
_ if finder.mutated => {
debug!("The discriminant: {discr:?} was mutated in predecessor: {pred:?}");
// Note that the discriminant could have been mutated in one predecessor
// but not the others, in which case only the predecessors which did not mutate
// the discriminant can be optimized.
finder.mutated = false;

continue 'preds;
}
_ => finder.visit_statement(statement, Location { block, statement_index }),
Contributor comment:

Could we create finder, visit, and check mutated in this arm, instead of the other arms?

}
}
}
}

if static_switches.is_empty() {
debug!("No static switches were found in the current body");
return false;
}

let basic_blocks = body.basic_blocks.as_mut();
let num_switches: usize = static_switches.iter().map(|(_, switches)| switches.len()).sum();
Contributor comment:

Counting the number of opts just for tracing is not useful.

for (block, static_switches) in static_switches {
Contributor comment:

Likewise, static_switches is a hashmap, iteration order is unstable.

for switch in static_switches {
debug!("{block:?}: Removing static switch: {switch:?}");

// We use the SSA, to destroy the SSA.
let data = {
let (block, pred) = basic_blocks.pick2_mut(block, switch.pred);
Contributor comment:

Could you add a comment explaining why block != switch.pred?
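
A possible answer, judging from the code in this diff: `block` ends in a `SwitchInt` while every recorded `pred` ends in a `Goto`, so the two cannot be the same block. In addition, a self-loop on `block` would count as a backedge, because every block dominates itself, and the loop-header check would already have rejected it. The standalone sketch below illustrates the second point on a toy CFG. The `dominators` and `is_loop_header` helpers are hypothetical, assume block 0 is the entry and that every block is reachable, and use a simple iterative data-flow computation rather than rustc's dominator implementation.

```rust
/// Returns `dom`, where `dom[b][d]` is true when block `d` dominates block `b`.
/// `succs[b]` lists the successors of block `b`; block 0 is the entry.
fn dominators(succs: &[Vec<usize>]) -> Vec<Vec<bool>> {
    let n = succs.len();
    let mut preds = vec![Vec::new(); n];
    for (block, targets) in succs.iter().enumerate() {
        for &target in targets {
            preds[target].push(block);
        }
    }

    // Start from "everything dominates everything" and shrink to a fixpoint.
    let mut dom = vec![vec![true; n]; n];
    dom[0] = vec![false; n];
    dom[0][0] = true;
    let mut changed = true;
    while changed {
        changed = false;
        for block in 1..n {
            // Intersect the dominator sets of all predecessors...
            let mut new = vec![true; n];
            for &pred in &preds[block] {
                for d in 0..n {
                    new[d] = new[d] && dom[pred][d];
                }
            }
            // ...and add the block itself, since dominance is reflexive.
            new[block] = true;
            if new != dom[block] {
                dom[block] = new;
                changed = true;
            }
        }
    }
    dom
}

/// A block is treated as a loop header if it dominates one of its own predecessors.
fn is_loop_header(succs: &[Vec<usize>], dom: &[Vec<bool>], block: usize) -> bool {
    succs
        .iter()
        .enumerate()
        .any(|(pred, targets)| targets.contains(&block) && dom[pred][block])
}
```

Under these assumptions a block with an edge to itself is always reported as a loop header, so the two indices passed to `pick2_mut` would be distinct.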

match switch.exclude {
Some(exclude) => {
pred.statements.extend(block.statements.iter().enumerate().filter_map(
|(index, statement)| {
if index == exclude { None } else { Some(statement.clone()) }
Contributor comment:

That's not necessary. We can run remove_unused_definitions together with remove_dead_blocks.

},
));
}
None => pred.statements.extend_from_slice(&block.statements),
}
pred
};
let terminator = data.terminator_mut();

// Make sure that we have not overwritten the terminator and it is still
// a `goto -> block`.
assert_eq!(terminator.kind, TerminatorKind::Goto { target: block });
// Something to be noted is that, this creates an edge from: `pred -> target`,
// and because we ensure that we do not thread through any loop headers, meaning
// it is not part of a loop, this edge will only ever appear once in the CFG.
terminator.kind = TerminatorKind::Goto { target: switch.target };
}
}

debug!("Removed {num_switches} static switches from: {:?}", body.source.def_id());
true
}

#[derive(Debug, Copy, Clone)]
struct StaticSwitch {
pred: BasicBlock,
target: BasicBlock,
exclude: Option<usize>,
}

#[derive(Debug, Copy, Clone)]
struct Discriminant {
discr: Local,
block: BasicBlock,
exclude: Option<usize>,
}

struct MutatedLocalFinder {
local: Local,
mutated: bool,
}

impl<'tcx> Visitor<'tcx> for MutatedLocalFinder {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _: Location) {
if self.local == place.local && let PlaceContext::MutatingUse(..) = context {
self.mutated = true;
}
Contributor comment:

Should we detect storage statements too?

}
}
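
As a source-level illustration of the pattern described in the pass overview, consider the sketch below. The `Foo` enum mirrors the one in the doc comment; the `before` and `after` functions are hypothetical and only show the intended effect. Whether rustc produces exactly this MIR shape for such code is an assumption, not something this PR states.

```rust
enum Foo {
    A(i32),
    B(i32),
}

/// Each branch builds a known variant, then both branches join at a `match`
/// that switches on the discriminant that was just written.
fn before(flag: bool, y: i32) -> i32 {
    let x = if flag { Foo::A(y) } else { Foo::B(y) };
    match x {
        Foo::A(v) => v + 1,
        Foo::B(v) => v - 1,
    }
}

/// What threading each predecessor directly to its match arm amounts to:
/// the variant is known per branch, so no runtime switch on `x` is needed.
fn after(flag: bool, y: i32) -> i32 {
    if flag { y + 1 } else { y - 1 }
}
```

The two functions compute the same result for every input; the second simply skips the intermediate switch on the discriminant.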
@@ -51,51 +51,44 @@
StorageLive(_10);
StorageLive(_11);
_9 = discriminant(_1);
switchInt(move _9) -> [0: bb7, 1: bb5, otherwise: bb6];
switchInt(move _9) -> [0: bb5, 1: bb3, otherwise: bb4];
Contributor comment:

Should we make this a unit test for SeparateConstSwitch?

}

bb1: {
StorageDead(_11);
StorageDead(_10);
_5 = discriminant(_3);
switchInt(move _5) -> [0: bb2, 1: bb4, otherwise: bb3];
}

bb2: {
_8 = ((_3 as Continue).0: i32);
_0 = Result::<i32, i32>::Ok(_8);
StorageDead(_3);
return;
}

bb3: {
unreachable;
}

bb4: {
bb2: {
_6 = ((_3 as Break).0: std::result::Result<std::convert::Infallible, i32>);
_13 = ((_6 as Err).0: i32);
_0 = Result::<i32, i32>::Err(move _13);
StorageDead(_3);
return;
}

bb5: {
bb3: {
_11 = ((_1 as Err).0: i32);
StorageLive(_12);
_12 = Result::<Infallible, i32>::Err(move _11);
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Break(move _12);
StorageDead(_12);
goto -> bb1;
StorageDead(_11);
StorageDead(_10);
goto -> bb2;
}

bb6: {
bb4: {
unreachable;
}

bb7: {
bb5: {
_10 = ((_1 as Ok).0: i32);
_3 = ControlFlow::<Result<Infallible, i32>, i32>::Continue(move _10);
StorageDead(_11);
StorageDead(_10);
goto -> bb1;
}
}