Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restrict promotion of const fn calls #121557

Merged
merged 5 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/const_eval/dummy_machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ impl<'mir, 'tcx: 'mir> interpret::Machine<'mir, 'tcx> for DummyMachine {
type MemoryKind = !;
const PANIC_ON_ALLOC_FAIL: bool = true;

// We want to just eval random consts in the program, so `eval_mir_const` can fail.
const ALL_CONSTS_ARE_PRECHECKED: bool = false;

#[inline(always)]
fn enforce_alignment(_ecx: &InterpCx<'mir, 'tcx, Self>) -> bool {
false // no reason to enforce alignment
Expand Down
22 changes: 11 additions & 11 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -822,15 +822,13 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
self.stack_mut().push(frame);

// Make sure all the constants required by this frame evaluate successfully (post-monomorphization check).
if M::POST_MONO_CHECKS {
for &const_ in &body.required_consts {
let c = self
.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
c.eval(*self.tcx, self.param_env, const_.span).map_err(|err| {
err.emit_note(*self.tcx);
err
})?;
}
for &const_ in &body.required_consts {
let c =
self.instantiate_from_current_frame_and_normalize_erasing_regions(const_.const_)?;
c.eval(*self.tcx, self.param_env, const_.span).map_err(|err| {
err.emit_note(*self.tcx);
err
})?;
}

// done
Expand Down Expand Up @@ -1181,8 +1179,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
) -> InterpResult<'tcx, OpTy<'tcx, M::Provenance>> {
M::eval_mir_constant(self, *val, span, layout, |ecx, val, span, layout| {
let const_val = val.eval(*ecx.tcx, ecx.param_env, span).map_err(|err| {
// FIXME: somehow this is reachable even when POST_MONO_CHECKS is on.
// Are we not always populating `required_consts`?
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
if M::ALL_CONSTS_ARE_PRECHECKED && !matches!(err, ErrorHandled::TooGeneric(..)) {
// Looks like the const is not captued by `required_consts`, that's bad.
bug!("interpret const eval failure of {val:?} which is not in required_consts");
}
err.emit_note(*ecx.tcx);
err
})?;
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_const_eval/src/interpret/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ pub trait Machine<'mir, 'tcx: 'mir>: Sized {
/// Should the machine panic on allocation failures?
const PANIC_ON_ALLOC_FAIL: bool;

/// Should post-monomorphization checks be run when a stack frame is pushed?
const POST_MONO_CHECKS: bool = true;
/// Determines whether `eval_mir_constant` can never fail because all required consts have
/// already been checked before.
const ALL_CONSTS_ARE_PRECHECKED: bool = true;

/// Whether memory accesses should be alignment-checked.
fn enforce_alignment(ecx: &InterpCx<'mir, 'tcx, Self>) -> bool;
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_middle/src/mir/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,20 @@ impl<'tcx> Const<'tcx> {
}
}

/// Determines whether we need to add this const to `required_consts`. This is the case if and
/// only if evaluating it may error.
#[inline]
pub fn is_required_const(&self) -> bool {
match self {
Const::Ty(c) => match c.kind() {
ty::ConstKind::Value(_) => false, // already a value, cannot error
_ => true,
},
Const::Val(..) => false, // already a value, cannot error
Const::Unevaluated(..) => true,
}
}

#[inline]
pub fn try_to_scalar(self) -> Option<Scalar> {
match self {
Expand Down
23 changes: 9 additions & 14 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -720,18 +720,12 @@ impl<'tcx> Inliner<'tcx> {
kind: TerminatorKind::Goto { target: integrator.map_block(START_BLOCK) },
});

// Copy only unevaluated constants from the callee_body into the caller_body.
// Although we are only pushing `ConstKind::Unevaluated` consts to
// `required_consts`, here we may not only have `ConstKind::Unevaluated`
// because we are calling `instantiate_and_normalize_erasing_regions`.
caller_body.required_consts.extend(callee_body.required_consts.iter().copied().filter(
|&ct| match ct.const_ {
Const::Ty(_) => {
bug!("should never encounter ty::UnevaluatedConst in `required_consts`")
RalfJung marked this conversation as resolved.
Show resolved Hide resolved
}
Const::Val(..) | Const::Unevaluated(..) => true,
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
},
));
// Copy required constants from the callee_body into the caller_body. Although we are only
// pushing unevaluated consts to `required_consts`, here they may have been evaluated
// because we are calling `instantiate_and_normalize_erasing_regions` -- so we filter again.
caller_body.required_consts.extend(
callee_body.required_consts.into_iter().filter(|ct| ct.const_.is_required_const()),
);
// Now that we incorporated the callee's `required_consts`, we can remove the callee from
// `mentioned_items` -- but we have to take their `mentioned_items` in return. This does
// some extra work here to save the monomorphization collector work later. It helps a lot,
Expand All @@ -747,8 +741,9 @@ impl<'tcx> Inliner<'tcx> {
caller_body.mentioned_items.remove(idx);
caller_body.mentioned_items.extend(callee_body.mentioned_items);
} else {
// If we can't find the callee, there's no point in adding its items.
// Probably it already got removed by being inlined elsewhere in the same function.
// If we can't find the callee, there's no point in adding its items. Probably it
// already got removed by being inlined elsewhere in the same function, so we already
// took its items.
}
}

Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,8 @@ fn mir_promoted(
body.tainted_by_errors = Some(error_reported);
}

// Collect `required_consts` *before* promotion, so if there are any consts being promoted
// we still add them to the list in the outer MIR body.
let mut required_consts = Vec::new();
let mut required_consts_visitor = RequiredConstsVisitor::new(&mut required_consts);
for (bb, bb_data) in traversal::reverse_postorder(&body) {
Expand Down
151 changes: 117 additions & 34 deletions compiler/rustc_mir_transform/src/promote_consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//! move analysis runs after promotion on broken MIR.

use either::{Left, Right};
use rustc_data_structures::fx::FxHashSet;
use rustc_hir as hir;
use rustc_middle::mir;
use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
Expand Down Expand Up @@ -175,6 +176,12 @@ fn collect_temps_and_candidates<'tcx>(
struct Validator<'a, 'tcx> {
ccx: &'a ConstCx<'a, 'tcx>,
temps: &'a mut IndexSlice<Local, TempState>,
/// For backwards compatibility, we are promoting function calls in `const`/`static`
/// initializers. But we want to avoid evaluating code that might panic and that otherwise would
/// not have been evaluated, so we only promote such calls in basic blocks that are guaranteed
/// to execute. In other words, we only promote such calls in basic blocks that are definitely
/// not dead code. Here we cache the result of computing that set of basic blocks.
promotion_safe_blocks: Option<FxHashSet<BasicBlock>>,
}

impl<'a, 'tcx> std::ops::Deref for Validator<'a, 'tcx> {
Expand Down Expand Up @@ -260,7 +267,9 @@ impl<'tcx> Validator<'_, 'tcx> {
self.validate_rvalue(rhs)
}
Right(terminator) => match &terminator.kind {
TerminatorKind::Call { func, args, .. } => self.validate_call(func, args),
TerminatorKind::Call { func, args, .. } => {
self.validate_call(func, args, loc.block)
}
TerminatorKind::Yield { .. } => Err(Unpromotable),
kind => {
span_bug!(terminator.source_info.span, "{:?} not promotable", kind);
Expand Down Expand Up @@ -588,53 +597,103 @@ impl<'tcx> Validator<'_, 'tcx> {
Ok(())
}

/// Computes the sets of blocks of this MIR that are definitely going to be executed
/// if the function returns successfully. That makes it safe to promote calls in them
/// that might fail.
fn promotion_safe_blocks(body: &mir::Body<'tcx>) -> FxHashSet<BasicBlock> {
let mut safe_blocks = FxHashSet::default();
let mut safe_block = START_BLOCK;
loop {
safe_blocks.insert(safe_block);
// Let's see if we can find another safe block.
safe_block = match body.basic_blocks[safe_block].terminator().kind {
TerminatorKind::Goto { target } => target,
TerminatorKind::Call { target: Some(target), .. }
| TerminatorKind::Drop { target, .. } => {
// This calls a function or the destructor. `target` does not get executed if
// the callee loops or panics. But in both cases the const already fails to
// evaluate, so we are fine considering `target` a safe block for promotion.
target
}
TerminatorKind::Assert { target, .. } => {
// Similar to above, we only consider successful execution.
target
}
_ => {
// No next safe block.
break;
}
};
}
safe_blocks
}

/// Returns whether the block is "safe" for promotion, which means it cannot be dead code.
/// We use this to avoid promoting operations that can fail in dead code.
fn is_promotion_safe_block(&mut self, block: BasicBlock) -> bool {
let body = self.body;
let safe_blocks =
self.promotion_safe_blocks.get_or_insert_with(|| Self::promotion_safe_blocks(body));
safe_blocks.contains(&block)
}

fn validate_call(
&mut self,
callee: &Operand<'tcx>,
args: &[Spanned<Operand<'tcx>>],
block: BasicBlock,
) -> Result<(), Unpromotable> {
// Validate the operands. If they fail, there's no question -- we cannot promote.
self.validate_operand(callee)?;
for arg in args {
self.validate_operand(&arg.node)?;
}

// Functions marked `#[rustc_promotable]` are explicitly allowed to be promoted, so we can
// accept them at this point.
let fn_ty = callee.ty(self.body, self.tcx);
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
if self.tcx.is_promotable_const_fn(def_id) {
return Ok(());
}
}

// Inside const/static items, we promote all (eligible) function calls.
// Everywhere else, we require `#[rustc_promotable]` on the callee.
let promote_all_const_fn = matches!(
// Ideally, we'd stop here and reject the rest.
// But for backward compatibility, we have to accept some promotion in const/static
// initializers. Inline consts are explicitly excluded, they are more recent so we have no
// backwards compatibility reason to allow more promotion inside of them.
let promote_all_fn = matches!(
self.const_kind,
Some(hir::ConstContext::Static(_) | hir::ConstContext::Const { inline: false })
);
if !promote_all_const_fn {
if let ty::FnDef(def_id, _) = *fn_ty.kind() {
// Never promote runtime `const fn` calls of
// functions without `#[rustc_promotable]`.
if !self.tcx.is_promotable_const_fn(def_id) {
return Err(Unpromotable);
}
}
if !promote_all_fn {
return Err(Unpromotable);
}

// Make sure the callee is a `const fn`.
let is_const_fn = match *fn_ty.kind() {
ty::FnDef(def_id, _) => self.tcx.is_const_fn_raw(def_id),
_ => false,
};
if !is_const_fn {
return Err(Unpromotable);
}

self.validate_operand(callee)?;
for arg in args {
self.validate_operand(&arg.node)?;
// The problem is, this may promote calls to functions that panic.
// We don't want to introduce compilation errors if there's a panic in a call in dead code.
// So we ensure that this is not dead code.
if !self.is_promotion_safe_block(block) {
return Err(Unpromotable);
}

// This passed all checks, so let's accept.
Ok(())
}
}

// FIXME(eddyb) remove the differences for promotability in `static`, `const`, `const fn`.
fn validate_candidates(
ccx: &ConstCx<'_, '_>,
temps: &mut IndexSlice<Local, TempState>,
candidates: &[Candidate],
) -> Vec<Candidate> {
let mut validator = Validator { ccx, temps };
let mut validator = Validator { ccx, temps, promotion_safe_blocks: None };

candidates
.iter()
Expand All @@ -653,6 +712,10 @@ struct Promoter<'a, 'tcx> {
/// If true, all nested temps are also kept in the
/// source MIR, not moved to the promoted MIR.
keep_original: bool,

/// If true, add the new const (the promoted) to the required_consts of the parent MIR.
/// This is initially false and then set by the visitor when it encounters a `Call` terminator.
add_to_required: bool,
}

impl<'a, 'tcx> Promoter<'a, 'tcx> {
Expand Down Expand Up @@ -755,6 +818,10 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
TerminatorKind::Call {
mut func, mut args, call_source: desugar, fn_span, ..
} => {
// This promoted involves a function call, so it may fail to evaluate.
// Let's make sure it is added to `required_consts` so that that failure cannot get lost.
self.add_to_required = true;

self.visit_operand(&mut func, loc);
for arg in &mut args {
self.visit_operand(&mut arg.node, loc);
Expand Down Expand Up @@ -789,7 +856,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {

fn promote_candidate(mut self, candidate: Candidate, next_promoted_id: usize) -> Body<'tcx> {
let def = self.source.source.def_id();
let mut rvalue = {
let (mut rvalue, promoted_op) = {
let promoted = &mut self.promoted;
let promoted_id = Promoted::new(next_promoted_id);
let tcx = self.tcx;
Expand All @@ -799,11 +866,7 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
let args = tcx.erase_regions(GenericArgs::identity_for_item(tcx, def));
let uneval = mir::UnevaluatedConst { def, args, promoted: Some(promoted_id) };

Operand::Constant(Box::new(ConstOperand {
span,
user_ty: None,
const_: Const::Unevaluated(uneval, ty),
}))
ConstOperand { span, user_ty: None, const_: Const::Unevaluated(uneval, ty) }
};

let blocks = self.source.basic_blocks.as_mut();
Expand Down Expand Up @@ -836,22 +899,26 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {
let promoted_ref = local_decls.push(promoted_ref);
assert_eq!(self.temps.push(TempState::Unpromotable), promoted_ref);

let promoted_operand = promoted_operand(ref_ty, span);
let promoted_ref_statement = Statement {
source_info: statement.source_info,
kind: StatementKind::Assign(Box::new((
Place::from(promoted_ref),
Rvalue::Use(promoted_operand(ref_ty, span)),
Rvalue::Use(Operand::Constant(Box::new(promoted_operand))),
))),
};
self.extra_statements.push((loc, promoted_ref_statement));

Rvalue::Ref(
tcx.lifetimes.re_erased,
*borrow_kind,
Place {
local: mem::replace(&mut place.local, promoted_ref),
projection: List::empty(),
},
(
Rvalue::Ref(
tcx.lifetimes.re_erased,
*borrow_kind,
Place {
local: mem::replace(&mut place.local, promoted_ref),
projection: List::empty(),
},
),
promoted_operand,
)
};

Expand All @@ -863,6 +930,12 @@ impl<'a, 'tcx> Promoter<'a, 'tcx> {

let span = self.promoted.span;
self.assign(RETURN_PLACE, rvalue, span);

// Now that we did promotion, we know whether we'll want to add this to `required_consts`.
if self.add_to_required {
self.source.required_consts.push(promoted_op);
}

self.promoted
}
}
Expand All @@ -878,6 +951,14 @@ impl<'a, 'tcx> MutVisitor<'tcx> for Promoter<'a, 'tcx> {
*local = self.promote_temp(*local);
}
}

fn visit_constant(&mut self, constant: &mut ConstOperand<'tcx>, _location: Location) {
if constant.const_.is_required_const() {
self.promoted.required_consts.push(*constant);
}

// Skipping `super_constant` as the visitor is otherwise only looking for locals.
}
}

fn promote_candidates<'tcx>(
Expand Down Expand Up @@ -931,8 +1012,10 @@ fn promote_candidates<'tcx>(
temps: &mut temps,
extra_statements: &mut extra_statements,
keep_original: false,
add_to_required: false,
};

// `required_consts` of the promoted itself gets filled while building the MIR body.
let mut promoted = promoter.promote_candidate(candidate, promotions.len());
promoted.source.promoted = Some(promotions.next_index());
promotions.push(promoted);
Expand Down
Loading
Loading