diff --git a/src/librustc_mir/interpret/eval_context.rs b/src/librustc_mir/interpret/eval_context.rs index 95e193b625354..c828931808138 100644 --- a/src/librustc_mir/interpret/eval_context.rs +++ b/src/librustc_mir/interpret/eval_context.rs @@ -131,6 +131,10 @@ pub enum LocalValue { } impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> { + /// Read the local's value or error if the local is not yet live or not live anymore. + /// + /// Note: This may only be invoked from the `Machine::access_local` hook and not from + /// anywhere else. You may be invalidating machine invariants if you do! pub fn access(&self) -> InterpResult<'tcx, Operand> { match self.value { LocalValue::Dead => throw_ub!(DeadLocal), @@ -143,6 +147,9 @@ impl<'tcx, Tag: Copy + 'static> LocalState<'tcx, Tag> { /// Overwrite the local. If the local can be overwritten in place, return a reference /// to do so; otherwise return the `MemPlace` to consult instead. + /// + /// Note: This may only be invoked from the `Machine::access_local_mut` hook and not from + /// anywhere else. You may be invalidating machine invariants if you do! pub fn access_mut( &mut self, ) -> InterpResult<'tcx, Result<&mut LocalValue, MemPlace>> { diff --git a/src/librustc_mir/interpret/machine.rs b/src/librustc_mir/interpret/machine.rs index b5dc40d955191..ec1c93c81657e 100644 --- a/src/librustc_mir/interpret/machine.rs +++ b/src/librustc_mir/interpret/machine.rs @@ -11,7 +11,7 @@ use rustc_span::def_id::DefId; use super::{ AllocId, Allocation, AllocationExtra, CheckInAllocMsg, Frame, ImmTy, InterpCx, InterpResult, - Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar, + LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand, PlaceTy, Pointer, Scalar, }; /// Data returned by Machine::stack_pop, @@ -192,6 +192,8 @@ pub trait Machine<'mir, 'tcx>: Sized { ) -> InterpResult<'tcx>; /// Called to read the specified `local` from the `frame`. + /// Since reading a ZST is not actually accessing memory or locals, this is never invoked + /// for ZST reads. #[inline] fn access_local( _ecx: &InterpCx<'mir, 'tcx, Self>, @@ -201,6 +203,21 @@ pub trait Machine<'mir, 'tcx>: Sized { frame.locals[local].access() } + /// Called to write the specified `local` from the `frame`. + /// Since writing a ZST is not actually accessing memory or locals, this is never invoked + /// for ZST reads. + #[inline] + fn access_local_mut<'a>( + ecx: &'a mut InterpCx<'mir, 'tcx, Self>, + frame: usize, + local: mir::Local, + ) -> InterpResult<'tcx, Result<&'a mut LocalValue, MemPlace>> + where + 'tcx: 'mir, + { + ecx.stack_mut()[frame].locals[local].access_mut() + } + /// Called before a basic block terminator is executed. /// You can use this to detect endlessly running programs. #[inline] diff --git a/src/librustc_mir/interpret/operand.rs b/src/librustc_mir/interpret/operand.rs index dd746f5cfb409..402bfc93361cf 100644 --- a/src/librustc_mir/interpret/operand.rs +++ b/src/librustc_mir/interpret/operand.rs @@ -432,7 +432,11 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { }) } - /// This is used by [priroda](https://github.com/oli-obk/priroda) to get an OpTy from a local + /// Read from a local. Will not actually access the local if reading from a ZST. + /// Will not access memory, instead an indirect `Operand` is returned. + /// + /// This is public because it is used by [priroda](https://github.com/oli-obk/priroda) to get an + /// OpTy from a local pub fn access_local( &self, frame: &super::Frame<'mir, 'tcx, M::PointerTag, M::FrameExtra>, diff --git a/src/librustc_mir/interpret/place.rs b/src/librustc_mir/interpret/place.rs index 396aec0a8f89f..f9729c5ad2fee 100644 --- a/src/librustc_mir/interpret/place.rs +++ b/src/librustc_mir/interpret/place.rs @@ -740,7 +740,7 @@ where // but not factored as a separate function. let mplace = match dest.place { Place::Local { frame, local } => { - match self.stack_mut()[frame].locals[local].access_mut()? { + match M::access_local_mut(self, frame, local)? { Ok(local) => { // Local can be updated in-place. *local = LocalValue::Live(Operand::Immediate(src)); @@ -973,7 +973,7 @@ where ) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::PointerTag>, Option)> { let (mplace, size) = match place.place { Place::Local { frame, local } => { - match self.stack_mut()[frame].locals[local].access_mut()? { + match M::access_local_mut(self, frame, local)? { Ok(&mut local_val) => { // We need to make an allocation. @@ -997,7 +997,7 @@ where } // Now we can call `access_mut` again, asserting it goes well, // and actually overwrite things. - *self.stack_mut()[frame].locals[local].access_mut().unwrap().unwrap() = + *M::access_local_mut(self, frame, local).unwrap().unwrap() = LocalValue::Live(Operand::Indirect(mplace)); (mplace, Some(size)) } diff --git a/src/librustc_mir/transform/const_prop.rs b/src/librustc_mir/transform/const_prop.rs index e1c5a4f5b1885..5da36ac1cf66e 100644 --- a/src/librustc_mir/transform/const_prop.rs +++ b/src/librustc_mir/transform/const_prop.rs @@ -4,6 +4,7 @@ use std::cell::Cell; use rustc_ast::ast::Mutability; +use rustc_data_structures::fx::FxHashSet; use rustc_hir::def::DefKind; use rustc_hir::HirId; use rustc_index::bit_set::BitSet; @@ -28,7 +29,7 @@ use rustc_trait_selection::traits; use crate::const_eval::error_to_const_error; use crate::interpret::{ self, compile_time_machine, AllocId, Allocation, Frame, ImmTy, Immediate, InterpCx, LocalState, - LocalValue, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer, + LocalValue, MemPlace, Memory, MemoryKind, OpTy, Operand as InterpOperand, PlaceTy, Pointer, ScalarMaybeUninit, StackPopCleanup, }; use crate::transform::{MirPass, MirSource}; @@ -151,11 +152,19 @@ impl<'tcx> MirPass<'tcx> for ConstProp { struct ConstPropMachine<'mir, 'tcx> { /// The virtual call stack. stack: Vec>, + /// `OnlyInsideOwnBlock` locals that were written in the current block get erased at the end. + written_only_inside_own_block_locals: FxHashSet, + /// Locals that need to be cleared after every block terminates. + only_propagate_inside_block_locals: BitSet, } impl<'mir, 'tcx> ConstPropMachine<'mir, 'tcx> { - fn new() -> Self { - Self { stack: Vec::new() } + fn new(only_propagate_inside_block_locals: BitSet) -> Self { + Self { + stack: Vec::new(), + written_only_inside_own_block_locals: Default::default(), + only_propagate_inside_block_locals, + } } } @@ -227,6 +236,18 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for ConstPropMachine<'mir, 'tcx> l.access() } + fn access_local_mut<'a>( + ecx: &'a mut InterpCx<'mir, 'tcx, Self>, + frame: usize, + local: Local, + ) -> InterpResult<'tcx, Result<&'a mut LocalValue, MemPlace>> + { + if frame == 0 && ecx.machine.only_propagate_inside_block_locals.contains(local) { + ecx.machine.written_only_inside_own_block_locals.insert(local); + } + ecx.machine.stack[frame].locals[local].access_mut() + } + fn before_access_global( _memory_extra: &(), _alloc_id: AllocId, @@ -274,8 +295,6 @@ struct ConstPropagator<'mir, 'tcx> { // Because we have `MutVisitor` we can't obtain the `SourceInfo` from a `Location`. So we store // the last known `SourceInfo` here and just keep revisiting it. source_info: Option, - // Locals we need to forget at the end of the current block - locals_of_current_block: BitSet, } impl<'mir, 'tcx> LayoutOf for ConstPropagator<'mir, 'tcx> { @@ -313,8 +332,20 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { let param_env = tcx.param_env(def_id).with_reveal_all(); let span = tcx.def_span(def_id); - let mut ecx = InterpCx::new(tcx, span, param_env, ConstPropMachine::new(), ()); let can_const_prop = CanConstProp::check(body); + let mut only_propagate_inside_block_locals = BitSet::new_empty(can_const_prop.len()); + for (l, mode) in can_const_prop.iter_enumerated() { + if *mode == ConstPropMode::OnlyInsideOwnBlock { + only_propagate_inside_block_locals.insert(l); + } + } + let mut ecx = InterpCx::new( + tcx, + span, + param_env, + ConstPropMachine::new(only_propagate_inside_block_locals), + (), + ); let ret = ecx .layout_of(body.return_ty().subst(tcx, substs)) @@ -345,7 +376,6 @@ impl<'mir, 'tcx> ConstPropagator<'mir, 'tcx> { //FIXME(wesleywiser) we can't steal this because `Visitor::super_visit_body()` needs it local_decls: body.local_decls.clone(), source_info: None, - locals_of_current_block: BitSet::new_empty(body.local_decls.len()), } } @@ -899,7 +929,6 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> { Will remove it from const-prop after block is finished. Local: {:?}", place.local ); - self.locals_of_current_block.insert(place.local); } ConstPropMode::OnlyPropagateInto | ConstPropMode::NoPropagation => { trace!("can't propagate into {:?}", place); @@ -1088,10 +1117,27 @@ impl<'mir, 'tcx> MutVisitor<'tcx> for ConstPropagator<'mir, 'tcx> { } } } - // We remove all Locals which are restricted in propagation to their containing blocks. - for local in self.locals_of_current_block.iter() { + + // We remove all Locals which are restricted in propagation to their containing blocks and + // which were modified in the current block. + // Take it out of the ecx so we can get a mutable reference to the ecx for `remove_const` + let mut locals = std::mem::take(&mut self.ecx.machine.written_only_inside_own_block_locals); + for &local in locals.iter() { Self::remove_const(&mut self.ecx, local); } - self.locals_of_current_block.clear(); + locals.clear(); + // Put it back so we reuse the heap of the storage + self.ecx.machine.written_only_inside_own_block_locals = locals; + if cfg!(debug_assertions) { + // Ensure we are correctly erasing locals with the non-debug-assert logic. + for local in self.ecx.machine.only_propagate_inside_block_locals.iter() { + assert!( + self.get_const(local.into()).is_none() + || self + .layout_of(self.local_decls[local].ty) + .map_or(true, |layout| layout.is_zst()) + ) + } + } } }