Skip to content

Commit

Permalink
separate bounds-check from alignment check
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Oct 15, 2023
1 parent e24835c commit b131fc1
Show file tree
Hide file tree
Showing 13 changed files with 89 additions and 136 deletions.
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/intern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,15 @@ impl<'rt, 'mir, 'tcx: 'mir, M: CompileTimeMachine<'mir, 'tcx, const_eval::Memory
// to avoid could be expensive: on the potentially larger types, arrays and slices,
// rather than on all aggregates unconditionally.
if matches!(mplace.layout.ty.kind(), ty::Array(..) | ty::Slice(..)) {
let Some((size, align)) = self.ecx.size_and_align_of_mplace(&mplace)? else {
let Some((size, _align)) = self.ecx.size_and_align_of_mplace(&mplace)? else {
// We do the walk if we can't determine the size of the mplace: we may be
// dealing with extern types here in the future.
return Ok(true);
};

// If there is no provenance in this allocation, it does not contain references
// that point to another allocation, and we can avoid the interning walk.
if let Some(alloc) = self.ecx.get_ptr_alloc(mplace.ptr(), size, align)? {
if let Some(alloc) = self.ecx.get_ptr_alloc(mplace.ptr(), size)? {
if !alloc.has_provenance() {
return Ok(false);
}
Expand Down
15 changes: 8 additions & 7 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_middle::ty::layout::{LayoutOf as _, ValidityRequirement};
use rustc_middle::ty::GenericArgsRef;
use rustc_middle::ty::{Ty, TyCtxt};
use rustc_span::symbol::{sym, Symbol};
use rustc_target::abi::{Abi, Align, Primitive, Size};
use rustc_target::abi::{Abi, Primitive, Size};

use super::{
util::ensure_monomorphic_enough, CheckInAllocMsg, ImmTy, InterpCx, Machine, OpTy, PlaceTy,
Expand Down Expand Up @@ -349,10 +349,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// Check that the range between them is dereferenceable ("in-bounds or one past the
// end of the same allocation"). This is like the check in ptr_offset_inbounds.
let min_ptr = if dist >= 0 { b } else { a };
self.check_ptr_access_align(
self.check_ptr_access(
min_ptr,
Size::from_bytes(dist.unsigned_abs()),
Align::ONE,
CheckInAllocMsg::OffsetFromTest,
)?;

Expand Down Expand Up @@ -581,10 +580,9 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// pointers to be properly aligned (unlike a read/write operation).
let min_ptr = if offset_bytes >= 0 { ptr } else { offset_ptr };
// This call handles checking for integer/null pointers.
self.check_ptr_access_align(
self.check_ptr_access(
min_ptr,
Size::from_bytes(offset_bytes.unsigned_abs()),
Align::ONE,
CheckInAllocMsg::PointerArithmeticTest,
)?;
Ok(offset_ptr)
Expand Down Expand Up @@ -613,7 +611,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let src = self.read_pointer(src)?;
let dst = self.read_pointer(dst)?;

self.mem_copy(src, align, dst, align, size, nonoverlapping)
self.check_ptr_align(src, align)?;
self.check_ptr_align(dst, align)?;

self.mem_copy(src, dst, size, nonoverlapping)
}

pub(crate) fn write_bytes_intrinsic(
Expand Down Expand Up @@ -669,7 +670,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
size|
-> InterpResult<'tcx, &[u8]> {
let ptr = this.read_pointer(op)?;
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size, Align::ONE)? else {
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size)? else {
// zero-sized access
return Ok(&[]);
};
Expand Down
129 changes: 48 additions & 81 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,14 +258,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
None => self.get_alloc_raw(alloc_id)?.size(),
};
// This will also call the access hooks.
self.mem_copy(
ptr,
Align::ONE,
new_ptr.into(),
Align::ONE,
old_size.min(new_size),
/*nonoverlapping*/ true,
)?;
self.mem_copy(ptr, new_ptr.into(), old_size.min(new_size), /*nonoverlapping*/ true)?;
self.deallocate_ptr(ptr, old_size_and_align, kind)?;

Ok(new_ptr)
Expand Down Expand Up @@ -367,12 +360,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
&self,
ptr: Pointer<Option<M::Provenance>>,
size: Size,
align: Align,
) -> InterpResult<'tcx, Option<(AllocId, Size, M::ProvenanceExtra)>> {
self.check_and_deref_ptr(
ptr,
size,
align,
CheckInAllocMsg::MemoryAccessTest,
|alloc_id, offset, prov| {
let (size, align) = self
Expand All @@ -382,17 +373,16 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
)
}

/// Check if the given pointer points to live memory of given `size` and `align`.
/// Check if the given pointer points to live memory of the given `size`.
/// The caller can control the error message for the out-of-bounds case.
#[inline(always)]
pub fn check_ptr_access_align(
pub fn check_ptr_access(
&self,
ptr: Pointer<Option<M::Provenance>>,
size: Size,
align: Align,
msg: CheckInAllocMsg,
) -> InterpResult<'tcx> {
self.check_and_deref_ptr(ptr, size, align, msg, |alloc_id, _, _| {
self.check_and_deref_ptr(ptr, size, msg, |alloc_id, _, _| {
let (size, align) = self.get_live_alloc_size_and_align(alloc_id, msg)?;
Ok((size, align, ()))
})?;
Expand All @@ -408,7 +398,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
&self,
ptr: Pointer<Option<M::Provenance>>,
size: Size,
align: Align,
msg: CheckInAllocMsg,
alloc_size: impl FnOnce(
AllocId,
Expand All @@ -423,17 +412,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
if size.bytes() > 0 || addr == 0 {
throw_ub!(DanglingIntPointer(addr, msg));
}
// Must be aligned.
if M::enforce_alignment(self) && align.bytes() > 1 {
self.check_misalign(
Self::offset_misalignment(addr, align),
CheckAlignMsg::AccessedPtr,
)?;
}
None
}
Ok((alloc_id, offset, prov)) => {
let (alloc_size, alloc_align, ret_val) = alloc_size(alloc_id, offset, prov)?;
let (alloc_size, _alloc_align, ret_val) = alloc_size(alloc_id, offset, prov)?;
// Test bounds. This also ensures non-null.
// It is sufficient to check this for the end pointer. Also check for overflow!
if offset.checked_add(size, &self.tcx).map_or(true, |end| end > alloc_size) {
Expand All @@ -449,14 +431,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
if M::Provenance::OFFSET_IS_ADDR {
assert_ne!(ptr.addr(), Size::ZERO);
}
// Test align. Check this last; if both bounds and alignment are violated
// we want the error to be about the bounds.
if M::enforce_alignment(self) && align.bytes() > 1 {
self.check_misalign(
self.alloc_misalignment(ptr, offset, align, alloc_align),
CheckAlignMsg::AccessedPtr,
)?;
}

// We can still be zero-sized in this branch, in which case we have to
// return `None`.
Expand All @@ -465,7 +439,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
})
}

#[inline(always)]
pub(super) fn check_misalign(
&self,
misaligned: Option<Misalignment>,
Expand All @@ -477,54 +450,55 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Ok(())
}

#[must_use]
fn offset_misalignment(offset: u64, align: Align) -> Option<Misalignment> {
if offset % align.bytes() == 0 {
None
} else {
// The biggest power of two through which `offset` is divisible.
let offset_pow2 = 1 << offset.trailing_zeros();
Some(Misalignment { has: Align::from_bytes(offset_pow2).unwrap(), required: align })
}
}

#[must_use]
fn alloc_misalignment(
pub(super) fn is_ptr_misaligned(
&self,
ptr: Pointer<Option<M::Provenance>>,
offset: Size,
align: Align,
alloc_align: Align,
) -> Option<Misalignment> {
if M::use_addr_for_alignment_check(self) {
// `use_addr_for_alignment_check` can only be true if `OFFSET_IS_ADDR` is true.
Self::offset_misalignment(ptr.addr().bytes(), align)
} else {
// Check allocation alignment and offset alignment.
if alloc_align.bytes() < align.bytes() {
Some(Misalignment { has: alloc_align, required: align })
if !M::enforce_alignment(self) || align.bytes() == 1 {
return None;
}

#[inline]
fn offset_misalignment(offset: u64, align: Align) -> Option<Misalignment> {
if offset % align.bytes() == 0 {
None
} else {
Self::offset_misalignment(offset.bytes(), align)
// The biggest power of two through which `offset` is divisible.
let offset_pow2 = 1 << offset.trailing_zeros();
Some(Misalignment { has: Align::from_bytes(offset_pow2).unwrap(), required: align })
}
}
}

pub(super) fn is_ptr_misaligned(
&self,
ptr: Pointer<Option<M::Provenance>>,
align: Align,
) -> Option<Misalignment> {
if !M::enforce_alignment(self) {
return None;
}
match self.ptr_try_get_alloc_id(ptr) {
Err(addr) => Self::offset_misalignment(addr, align),
Err(addr) => offset_misalignment(addr, align),
Ok((alloc_id, offset, _prov)) => {
let (_size, alloc_align, _kind) = self.get_alloc_info(alloc_id);
self.alloc_misalignment(ptr, offset, align, alloc_align)
if M::use_addr_for_alignment_check(self) {
// `use_addr_for_alignment_check` can only be true if `OFFSET_IS_ADDR` is true.
offset_misalignment(ptr.addr().bytes(), align)
} else {
// Check allocation alignment and offset alignment.
if alloc_align.bytes() < align.bytes() {
Some(Misalignment { has: alloc_align, required: align })
} else {
offset_misalignment(offset.bytes(), align)
}
}
}
}
}

/// Checks a pointer for misalignment.
///
/// The error assumes this is checking the pointer used directly for an access.
pub fn check_ptr_align(
&self,
ptr: Pointer<Option<M::Provenance>>,
align: Align,
) -> InterpResult<'tcx> {
self.check_misalign(self.is_ptr_misaligned(ptr, align), CheckAlignMsg::AccessedPtr)
}
}

/// Allocation accessors
Expand Down Expand Up @@ -629,18 +603,16 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
}
}

/// "Safe" (bounds and align-checked) allocation access.
/// Bounds-checked *but not align-checked* allocation access.
pub fn get_ptr_alloc<'a>(
&'a self,
ptr: Pointer<Option<M::Provenance>>,
size: Size,
align: Align,
) -> InterpResult<'tcx, Option<AllocRef<'a, 'tcx, M::Provenance, M::AllocExtra, M::Bytes>>>
{
let ptr_and_alloc = self.check_and_deref_ptr(
ptr,
size,
align,
CheckInAllocMsg::MemoryAccessTest,
|alloc_id, offset, prov| {
let alloc = self.get_alloc_raw(alloc_id)?;
Expand Down Expand Up @@ -701,15 +673,14 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Ok((alloc, &mut self.machine))
}

/// "Safe" (bounds and align-checked) allocation access.
/// Bounds-checked *but not align-checked* allocation access.
pub fn get_ptr_alloc_mut<'a>(
&'a mut self,
ptr: Pointer<Option<M::Provenance>>,
size: Size,
align: Align,
) -> InterpResult<'tcx, Option<AllocRefMut<'a, 'tcx, M::Provenance, M::AllocExtra, M::Bytes>>>
{
let parts = self.get_ptr_access(ptr, size, align)?;
let parts = self.get_ptr_access(ptr, size)?;
if let Some((alloc_id, offset, prov)) = parts {
let tcx = *self.tcx;
// FIXME: can we somehow avoid looking up the allocation twice here?
Expand Down Expand Up @@ -1066,7 +1037,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
ptr: Pointer<Option<M::Provenance>>,
size: Size,
) -> InterpResult<'tcx, &[u8]> {
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size, Align::ONE)? else {
let Some(alloc_ref) = self.get_ptr_alloc(ptr, size)? else {
// zero-sized access
return Ok(&[]);
};
Expand All @@ -1092,7 +1063,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
assert_eq!(lower, len, "can only write iterators with a precise length");

let size = Size::from_bytes(len);
let Some(alloc_ref) = self.get_ptr_alloc_mut(ptr, size, Align::ONE)? else {
let Some(alloc_ref) = self.get_ptr_alloc_mut(ptr, size)? else {
// zero-sized access
assert_matches!(src.next(), None, "iterator said it was empty but returned an element");
return Ok(());
Expand All @@ -1117,29 +1088,25 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
pub fn mem_copy(
&mut self,
src: Pointer<Option<M::Provenance>>,
src_align: Align,
dest: Pointer<Option<M::Provenance>>,
dest_align: Align,
size: Size,
nonoverlapping: bool,
) -> InterpResult<'tcx> {
self.mem_copy_repeatedly(src, src_align, dest, dest_align, size, 1, nonoverlapping)
self.mem_copy_repeatedly(src, dest, size, 1, nonoverlapping)
}

pub fn mem_copy_repeatedly(
&mut self,
src: Pointer<Option<M::Provenance>>,
src_align: Align,
dest: Pointer<Option<M::Provenance>>,
dest_align: Align,
size: Size,
num_copies: u64,
nonoverlapping: bool,
) -> InterpResult<'tcx> {
let tcx = self.tcx;
// We need to do our own bounds-checks.
let src_parts = self.get_ptr_access(src, size, src_align)?;
let dest_parts = self.get_ptr_access(dest, size * num_copies, dest_align)?; // `Size` multiplication
let src_parts = self.get_ptr_access(src, size)?;
let dest_parts = self.get_ptr_access(dest, size * num_copies)?; // `Size` multiplication

// FIXME: we look up both allocations twice here, once before for the `check_ptr_access`
// and once below to get the underlying `&[mut] Allocation`.
Expand Down
13 changes: 3 additions & 10 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ where
.unwrap_or((mplace.layout.size, mplace.layout.align.abi));
// We check alignment separately, and *after* checking everything else.
// If an access is both OOB and misaligned, we want to see the bounds error.
let a = self.get_ptr_alloc(mplace.ptr(), size, Align::ONE)?;
let a = self.get_ptr_alloc(mplace.ptr(), size)?;
self.check_misalign(mplace.mplace.misaligned, CheckAlignMsg::BasedOn)?;
Ok(a)
}
Expand All @@ -478,7 +478,7 @@ where
// If an access is both OOB and misaligned, we want to see the bounds error.
// However we have to call `check_misalign` first to make the borrow checker happy.
let misalign_err = self.check_misalign(mplace.mplace.misaligned, CheckAlignMsg::BasedOn);
let a = self.get_ptr_alloc_mut(mplace.ptr(), size, Align::ONE)?;
let a = self.get_ptr_alloc_mut(mplace.ptr(), size)?;
misalign_err?;
Ok(a)
}
Expand Down Expand Up @@ -873,14 +873,7 @@ where
// non-overlapping.)
// We check alignment separately, and *after* checking everything else.
// If an access is both OOB and misaligned, we want to see the bounds error.
self.mem_copy(
src.ptr(),
Align::ONE,
dest.ptr(),
Align::ONE,
dest_size,
/*nonoverlapping*/ true,
)?;
self.mem_copy(src.ptr(), dest.ptr(), dest_size, /*nonoverlapping*/ true)?;
self.check_misalign(src.mplace.misaligned, CheckAlignMsg::BasedOn)?;
self.check_misalign(dest.mplace.misaligned, CheckAlignMsg::BasedOn)?;
Ok(())
Expand Down
Loading

0 comments on commit b131fc1

Please sign in to comment.