Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Const prop: erase all block-only locals at the end of every block #73757

Merged
merged 1 commit into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/librustc_mir/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ pub enum LocalValue<Tag = ()> {
}

impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {
/// Read the local's value or error if the local is not yet live or not live anymore.
///
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Note: This may only be invoked from the `Machine::access_local` hook and not from
/// Note: This must only be invoked from the `Machine::access_local` hook and not from

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had must before 😆 I think may is correct here.

/// anywhere else. You may be invalidating machine invariants if you do!
pub fn access(&self) -> InterpResult<'tcx, Operand<Tag>> {
match self.value {
LocalValue::Dead => throw_ub!(DeadLocal),
Expand All @@ -143,6 +147,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> {

/// Overwrite the local. If the local can be overwritten in place, return a reference
/// to do so; otherwise return the `MemPlace` to consult instead.
///
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from
/// Note: This must only be invoked from the `Machine::access_local_mut` hook and not from

/// anywhere else. You may be invalidating machine invariants if you do!
pub fn access_mut(
&mut self,
) -> InterpResult<'tcx, Result<&mut LocalValue<Tag>, MemPlace<Tag>>> {
Expand Down
19 changes: 18 additions & 1 deletion src/librustc_mir/interpret/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use rustc_span::def_id::DefId;

use super::{
AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult,
Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar,
};

/// Data returned by Machine::stack_pop,
Expand Down Expand Up @@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized {
) -> InterpResult<'tcx>;

/// Called to read the specified `local` from the `frame`.
/// Since reading a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST reads.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, this was already the case before but now is documented better?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

#[inline]
fn access_local(
_ecx: &InterpCx<'mir, 'tcx, Self>,
Expand All @@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized {
frame.locals[local].access()
}

/// Called to write the specified `local` from the `frame`.
/// Since writing a ZST is not actually accessing memory or locals, this is never invoked
/// for ZST reads.
#[inline]
fn access_local_mut<'a>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @RalfJung some miri engine and machine extensions happen in this PR. They don't change anything, they just allow the const propagator to hook into accessing a local mutably and run some extra code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the heads-up!

ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
frame: usize,
local: mir::Local,
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
where
'tcx: 'mir,
{
ecx.stack_mut()[frame].locals[local].access_mut()
}

/// Called before a basic block terminator is executed.
/// You can use this to detect endlessly running programs.
#[inline]
Expand Down
6 changes: 5 additions & 1 deletion src/librustc_mir/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
})
}

/// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local
/// Read from a local. Will not actually access the local if reading from a ZST.
/// Will not access memory, instead an indirect `Operand` is returned.
///
/// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an
/// OpTy from a local
pub fn access_local(
&self,
frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>,
Expand Down
6 changes: 3 additions & 3 deletions src/librustc_mir/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ where
// but not factored as a separate function.
let mplace = match dest.place {
Place::Local { frame, local } => {
match self.stack_mut()[frame].locals[local].access_mut()? {
match M::access_local_mut(self, frame, local)? {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a risk here of accidentally using the frame method instead of the hook. Is there anything we can do about that?

I think there should at least be a comment at the frame method saying to not call it (except when implementing the hook).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll investigate whether we can have something static and otherwise document it properly

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I documented it. I don't think it's realistic to statically prevent it being called except by introducing some appropriately named token types.

Ok(local) => {
// Local can be updated in-place.
*local = LocalValue::Live(Operand::Immediate(src));
Expand Down Expand Up @@ -973,7 +973,7 @@ where
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option<Size>)> {
let (mplace, size) = match place.place {
Place::Local { frame, local } => {
match self.stack_mut()[frame].locals[local].access_mut()? {
match M::access_local_mut(self, frame, local)? {
Ok(&mut local_val) => {
// We need to make an allocation.

Expand All @@ -997,7 +997,7 @@ where
}
// Now we can call `access_mut` again, asserting it goes well,
// and actually overwrite things.
*self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() =
*M::access_local_mut(self, frame, local).unwrap().unwrap() =
LocalValue::Live(Operand::Indirect(mplace));
(mplace, Some(size))
}
Expand Down
68 changes: 57 additions & 11 deletions src/librustc_mir/transform/const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use std::cell::Cell;

use rustc_ast::ast::Mutability;
use rustc_data_structures::fx::FxHashSet;
use rustc_hir::def::DefKind;
use rustc_hir::HirId;
use rustc_index::bit_set::BitSet;
Expand All @@ -28,7 +29,7 @@ use rustc_trait_selection::traits;
use crate::const_eval::error_to_const_error;
use crate::interpret::{
self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState,
LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer,
ScalarMaybeUninit, StackPopCleanup,
};
use crate::transform::{MirPass, MirSource};
Expand Down Expand Up @@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp {
struct ConstPropMachine<'mir, 'tcx> {
/// The virtual call stack.
stack: Vec<Frame<'mir, 'tcx, (), ()>>,
/// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end.
written_only_inside_own_block_locals: FxHashSet<Local>,
wesleywiser marked this conversation as resolved.
Show resolved Hide resolved
/// Locals that need to be cleared after every block terminates.
only_propagate_inside_block_locals: BitSet<Local>,
}

impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> {
fn new() -> Self {
Self { stack: Vec::new() }
fn new(only_propagate_inside_block_locals: BitSet<Local>) -> Self {
Self {
stack: Vec::new(),
written_only_inside_own_block_locals: Default::default(),
only_propagate_inside_block_locals,
}
}
}

Expand Down Expand Up @@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx>
l.access()
}

fn access_local_mut<'a>(
ecx: &'a mut InterpCx<'mir, 'tcx, Self>,
frame: usize,
local: Local,
) -> InterpResult<'tcx, Result<&'a mut LocalValue<Self::PointerTag>, MemPlace<Self::PointerTag>>>
{
if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) {
ecx.machine.written_only_inside_own_block_locals.insert(local);
}
ecx.machine.stack[frame].locals[local].access_mut()
}

fn before_access_global(
_memory_extra: &(),
_alloc_id: AllocId,
Expand Down Expand Up @@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> {
// Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store
// the last known `SourceInfo` here and just keep revisiting it.
source_info: Option<SourceInfo>,
// Locals we need to forget at the end of the current block
locals_of_current_block: BitSet<Local>,
}

impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> {
Expand Down Expand Up @@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
let param_env = tcx.param_env(def_id).with_reveal_all();

let span = tcx.def_span(def_id);
let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ());
let can_const_prop = CanConstProp::check(body);
let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len());
for (l, mode) in can_const_prop.iter_enumerated() {
if *mode == ConstPropMode::OnlyInsideOwnBlock {
only_propagate_inside_block_locals.insert(l);
}
}
let mut ecx = InterpCx::new(
tcx,
span,
param_env,
ConstPropMachine::new(only_propagate_inside_block_locals),
(),
);

let ret = ecx
.layout_of(body.return_ty().subst(tcx, substs))
Expand Down Expand Up @@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> {
//FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it
local_decls: body.local_decls.clone(),
source_info: None,
locals_of_current_block: BitSet::new_empty(body.local_decls.len()),
}
}

Expand Down Expand Up @@ -899,7 +929,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
Will remove it from const-prop after block is finished. Local: {:?}",
place.local
);
self.locals_of_current_block.insert(place.local);
}
ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => {
trace!("can't propagate into {:?}", place);
Expand Down Expand Up @@ -1088,10 +1117,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> {
}
}
}
// We remove all Locals which are restricted in propagation to their containing blocks.
for local in self.locals_of_current_block.iter() {

// We remove all Locals which are restricted in propagation to their containing blocks and
// which were modified in the current block.
// Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const`
let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals);
for &local in locals.iter() {
Self::remove_const(&mut self.ecx, local);
}
self.locals_of_current_block.clear();
locals.clear();
// Put it back so we reuse the heap of the storage
self.ecx.machine.written_only_inside_own_block_locals = locals;
if cfg!(debug_assertions) {
// Ensure we are correctly erasing locals with the non-debug-assert logic.
for local in self.ecx.machine.only_propagate_inside_block_locals.iter() {
assert!(
self.get_const(local.into()).is_none()
|| self
.layout_of(self.local_decls[local].ty)
.map_or(true, |layout| layout.is_zst())
)
}
}
}
}