Skip to content

Commit

Permalink
Rollup merge of #73757 - oli-obk:const_prop_hardening, r=wesleywiser
Browse files Browse the repository at this point in the history
Const prop: erase all block-only locals at the end of every block

I messed up this erasure in #73656 (comment). I think it is too fragile to have the previous scheme. Let's benchmark the new scheme and see what happens.

r? @wesleywiser

cc @felix91gr
  • Loading branch information
Manishearth authored Jun 28, 2020
2 parents ec48989 + b9f4e0d commit ccc1bf7
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 16 deletions.
7 changes: 7 additions & 0 deletions src/librustc_mir/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ pub enum LocalValue<Tag = ()> {
}

impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
/// Read the local's value or error if the local is not yet live or not live anymore.
///
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
/// anywhere else. You may be invalidating machine invariants if you do!
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
match self.value {
LocalValue::Dead => throw_ub!(DeadLocal),
Expand All @@ -144,6 +148,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {

/// Overwrite the local. If the local can be overwritten in place, return a reference
/// to do so; otherwise return the `MemPlace` to consult instead.
///
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
/// anywhere else. You may be invalidating machine invariants if you do!
pub fn access_mut(
&mut self,
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
Expand Down
19 changes: 18 additions & 1 deletion src/librustc_mir/interpret/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;

use super::{
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
};

/// Data returned by Machine::stack_pop,
Expand Down Expand Up @@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
) -> InterpResult<'tcx>;

/// Called to read the specified `local` from the `frame`.
/// Since reading a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST reads.
#[inline]
fn access_local(
_ecx: &InterpCx<'mir, 'tcx, Self>,
Expand All @@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
frame.locals[local].access()
}

/// Called to write the specified `local` from the `frame`.
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST reads.
#[inline]
fn access_local_mut<'a>(
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
frame: usize,
local: mir::Local,
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
where
'tcx: 'mir,
{
ecx.stack_mut()[frame].locals[local].access_mut()
}

/// Called before a basic block terminator is executed.
/// You can use this to detect endlessly running programs.
#[inline]
Expand Down
6 changes: 5 additions & 1 deletion src/librustc_mir/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
})
}

/// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
/// Read from a local. Will not actually access the local if reading from a ZST.
/// Will not access memory, instead an indirect `Operand` is returned.
///
/// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
/// OpTy from a local
pub fn access_local(
&self,
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
Expand Down
6 changes: 3 additions & 3 deletions src/librustc_mir/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ where
// but not factored as a separate function.
let mplace = match dest.place {
Place::Local { frame, local } => {
match self.stack_mut()[frame].locals[local].access_mut()? {
match M::access_local_mut(self, frame, local)? {
Ok(local) => {
// Local can be updated in-place.
*local = LocalValue::Live(Operand::Immediate(src));
Expand Down Expand Up @@ -974,7 +974,7 @@ where
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
let (mplace, size) = match place.place {
Place::Local { frame, local } => {
match self.stack_mut()[frame].locals[local].access_mut()? {
match M::access_local_mut(self, frame, local)? {
Ok(&mut local_val) => {
// We need to make an allocation.

Expand All @@ -998,7 +998,7 @@ where
}
// Now we can call `access_mut` again, asserting it goes well,
// and actually overwrite things.
*self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
*M::access_local_mut(self, frame, local).unwrap().unwrap() =
LocalValue::Live(Operand::Indirect(mplace));
(mplace, Some(size))
}
Expand Down
68 changes: 57 additions & 11 deletions src/librustc_mir/transform/const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use std::cell::Cell;

use rustc_ast::ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def::DefKind;
use rustc_hir::HirId;
use rustc_index::bit_set::BitSet;
Expand All @@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
use crate::const_eval::error_to_const_error;
use crate::interpret::{
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
ScalarMaybeUninit, StackPopCleanup,
};
use crate::transform::{MirPass, MirSource};
Expand Down Expand Up @@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
struct ConstPropMachine<'mir, 'tcx> {
/// The virtual call stack.
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
/// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
written_only_inside_own_block_locals: FxHashSet<Local>,
/// Locals that need to be cleared after every block terminates.
only_propagate_inside_block_locals: BitSet<Local>,
}

impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
fn new() -> Self {
Self { stack: Vec::new() }
fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
Self {
stack: Vec::new(),
written_only_inside_own_block_locals: Default::default(),
only_propagate_inside_block_locals,
}
}
}

Expand Down Expand Up @@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
l.access()
}

fn access_local_mut<'a>(
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
frame: usize,
local: Local,
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
{
if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
ecx.machine.written_only_inside_own_block_locals.insert(local);
}
ecx.machine.stack[frame].locals[local].access_mut()
}

fn before_access_global(
_memory_extra: &(),
_alloc_id: AllocId,
Expand Down Expand Up @@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
// the last known `SourceInfo` here and just keep revisiting it.
source_info: Option<SourceInfo>,
// Locals we need to forget at the end of the current block
locals_of_current_block: BitSet<Local>,
}

impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
Expand Down Expand Up @@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
let param_env = tcx.param_env(def_id).with_reveal_all();

let span = tcx.def_span(def_id);
let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
let can_const_prop = CanConstProp::check(body);
let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
for (l, mode) in can_const_prop.iter_enumerated() {
if *mode == ConstPropMode::OnlyInsideOwnBlock {
only_propagate_inside_block_locals.insert(l);
}
}
let mut ecx = InterpCx::new(
tcx,
span,
param_env,
ConstPropMachine::new(only_propagate_inside_block_locals),
(),
);

let ret = ecx
.layout_of(body.return_ty().subst(tcx, substs))
Expand Down Expand Up @@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
local_decls: body.local_decls.clone(),
source_info: None,
locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
}
}

Expand Down Expand Up @@ -900,7 +930,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
Will remove it from const-prop after block is finished. Local: {:?}",
place.local
);
self.locals_of_current_block.insert(place.local);
}
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
trace!("can't propagate into {:?}", place);
Expand Down Expand Up @@ -1089,10 +1118,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
}
}
}
// We remove all Locals which are restricted in propagation to their containing blocks.
for local in self.locals_of_current_block.iter() {

// We remove all Locals which are restricted in propagation to their containing blocks and
// which were modified in the current block.
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
for &local in locals.iter() {
Self::remove_const(&mut self.ecx, local);
}
self.locals_of_current_block.clear();
locals.clear();
// Put it back so we reuse the heap of the storage
self.ecx.machine.written_only_inside_own_block_locals = locals;
if cfg!(debug_assertions) {
// Ensure we are correctly erasing locals with the non-debug-assert logic.
for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
assert!(
self.get_const(local.into()).is_none()
|| self
.layout_of(self.local_decls[local].ty)
.map_or(true, |layout| layout.is_zst())
)
}
}
}
}

0 comments on commit ccc1bf7

Please sign in to comment.