Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Insert alignment checks for pointer dereferences when debug assertions are enabled #98112

Merged
merged 8 commits into from
Mar 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,18 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) {
source_info.span,
);
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
let required = codegen_operand(fx, required).load_scalar(fx);
let found = codegen_operand(fx, found).load_scalar(fx);
let location = fx.get_caller_location(source_info).load_scalar(fx);

codegen_panic_inner(
fx,
rustc_hir::LangItem::PanicBoundsCheck,
&[required, found, location],
source_info.span,
);
}
_ => {
let msg_str = msg.description();
codegen_panic(fx, msg_str, source_info);
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,13 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicBoundsCheck, vec![index, len, location])
}
AssertKind::MisalignedPointerDereference { ref required, ref found } => {
let required = self.codegen_operand(bx, required).immediate();
let found = self.codegen_operand(bx, found).immediate();
// It's `fn panic_bounds_check(index: usize, len: usize)`,
// and `#[track_caller]` adds an implicit third argument.
(LangItem::PanicMisalignedPointerDereference, vec![required, found, location])
}
_ => {
let msg = bx.const_str(msg.description());
// It's `pub fn panic(expr: &str)`, with the wide reference being passed
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_const_eval/src/const_eval/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,12 @@ impl<'mir, 'tcx> interpret::Machine<'mir, 'tcx> for CompileTimeInterpreter<'mir,
RemainderByZero(op) => RemainderByZero(eval_to_int(op)?),
ResumedAfterReturn(generator_kind) => ResumedAfterReturn(*generator_kind),
ResumedAfterPanic(generator_kind) => ResumedAfterPanic(*generator_kind),
MisalignedPointerDereference { ref required, ref found } => {
MisalignedPointerDereference {
required: eval_to_int(required)?,
found: eval_to_int(found)?,
}
}
};
Err(ConstEvalErrKind::AssertFailure(err).into())
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ language_item_table! {
PanicDisplay, sym::panic_display, panic_display, Target::Fn, GenericRequirement::None;
ConstPanicFmt, sym::const_panic_fmt, const_panic_fmt, Target::Fn, GenericRequirement::None;
PanicBoundsCheck, sym::panic_bounds_check, panic_bounds_check_fn, Target::Fn, GenericRequirement::Exact(0);
PanicMisalignedPointerDereference, sym::panic_misaligned_pointer_dereference, panic_misaligned_pointer_dereference_fn, Target::Fn, GenericRequirement::Exact(0);
PanicInfo, sym::panic_info, panic_info, Target::Struct, GenericRequirement::None;
PanicLocation, sym::panic_location, panic_location, Target::Struct, GenericRequirement::None;
PanicImpl, sym::panic_impl, panic_impl, Target::Fn, GenericRequirement::None;
Expand Down
20 changes: 18 additions & 2 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ impl<O> AssertKind<O> {

/// Getting a description does not require `O` to be printable, and does not
/// require allocation.
/// The caller is expected to handle `BoundsCheck` separately.
/// The caller is expected to handle `BoundsCheck` and `MisalignedPointerDereference` separately.
pub fn description(&self) -> &'static str {
use AssertKind::*;
match self {
Expand All @@ -1296,7 +1296,9 @@ impl<O> AssertKind<O> {
ResumedAfterReturn(GeneratorKind::Async(_)) => "`async fn` resumed after completion",
ResumedAfterPanic(GeneratorKind::Gen) => "generator resumed after panicking",
ResumedAfterPanic(GeneratorKind::Async(_)) => "`async fn` resumed after panicking",
BoundsCheck { .. } => bug!("Unexpected AssertKind"),
BoundsCheck { .. } | MisalignedPointerDereference { .. } => {
bug!("Unexpected AssertKind")
}
}
}

Expand Down Expand Up @@ -1353,6 +1355,13 @@ impl<O> AssertKind<O> {
Overflow(BinOp::Shl, _, r) => {
write!(f, "\"attempt to shift left by `{{}}`, which would overflow\", {:?}", r)
}
MisalignedPointerDereference { required, found } => {
write!(
f,
"\"misaligned pointer dereference: address must be a multiple of {{}} but is {{}}\", {:?}, {:?}",
required, found
)
}
_ => write!(f, "\"{}\"", self.description()),
}
}
Expand Down Expand Up @@ -1397,6 +1406,13 @@ impl<O: fmt::Debug> fmt::Debug for AssertKind<O> {
Overflow(BinOp::Shl, _, r) => {
write!(f, "attempt to shift left by `{:#?}`, which would overflow", r)
}
MisalignedPointerDereference { required, found } => {
write!(
f,
"misaligned pointer dereference: address must be a multiple of {:?} but is {:?}",
required, found
)
}
_ => write!(f, "{}", self.description()),
}
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/mir/syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ pub enum AssertKind<O> {
RemainderByZero(O),
ResumedAfterReturn(GeneratorKind),
ResumedAfterPanic(GeneratorKind),
MisalignedPointerDereference { required: O, found: O },
}

#[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)]
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/mir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,10 @@ macro_rules! make_mir_visitor {
ResumedAfterReturn(_) | ResumedAfterPanic(_) => {
// Nothing to visit
}
MisalignedPointerDereference { required, found } => {
self.visit_operand(required, location);
self.visit_operand(found, location);
}
}
}

Expand Down
227 changes: 227 additions & 0 deletions compiler/rustc_mir_transform/src/check_alignment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
use crate::MirPass;
use rustc_hir::def_id::DefId;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
interpret::{ConstValue, Scalar},
visit::{PlaceContext, Visitor},
};
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;

pub struct CheckAlignment;

impl<'tcx> MirPass<'tcx> for CheckAlignment {
fn is_enabled(&self, sess: &Session) -> bool {
sess.opts.debug_assertions
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let basic_blocks = body.basic_blocks.as_mut();
let local_decls = &mut body.local_decls;

for block in (0..basic_blocks.len()).rev() {
let block = block.into();
for statement_index in (0..basic_blocks[block].statements.len()).rev() {
let location = Location { block, statement_index };
let statement = &basic_blocks[block].statements[statement_index];
let source_info = statement.source_info;

let mut finder = PointerFinder {
local_decls,
tcx,
pointers: Vec::new(),
def_id: body.source.def_id(),
};
for (pointer, pointee_ty) in finder.find_pointers(statement) {
debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty);

let new_block = split_block(basic_blocks, location);
insert_alignment_check(
tcx,
local_decls,
&mut basic_blocks[block],
pointer,
pointee_ty,
source_info,
new_block,
);
}
}
}
}
}

impl<'tcx, 'a> PointerFinder<'tcx, 'a> {
fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> {
self.pointers.clear();
self.visit_statement(statement, Location::START);
core::mem::take(&mut self.pointers)
}
}

struct PointerFinder<'tcx, 'a> {
local_decls: &'a mut LocalDecls<'tcx>,
tcx: TyCtxt<'tcx>,
def_id: DefId,
pointers: Vec<(Place<'tcx>, Ty<'tcx>)>,
}

impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> {
fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) {
if let PlaceContext::NonUse(_) = context {
return;
}
if !place.is_indirect() {
return;
}

let pointer = Place::from(place.local);
let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty;

// We only want to check unsafe pointers
if !pointer_ty.is_unsafe_ptr() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this do the wrong thing for something like **x with x: &*const i32? Here x is safe but we still deref a raw ptr that needs checking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at this stage there are no places like **x. They have all been converted to

tmp = *x;
*tmp

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh wow, that is a very subtle non-local invariant. At the very least the code needs a detailed comment explaining that. Also we better make sure some check (MIR validation?) ICEs if that invariant is ever violated.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we already have a MIR validation that only the first projection may be a Deref. I have tripped it often working on MIR opts:

&& place.projection[1..].contains(&ProjectionElem::Deref)

trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty);
return;
}

let Some(pointee) = pointer_ty.builtin_deref(true) else {
debug!("Indirect but no builtin deref: {:?}", pointer_ty);
return;
};
let mut pointee_ty = pointee.ty;
if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() {
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
pointee_ty = pointee_ty.sequence_element_type(self.tcx);
}

if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) {
debug!("Unsafe pointer, but unsized: {:?}", pointer_ty);
return;
}

if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_]
.contains(&pointee_ty)
{
debug!("Trivially aligned pointee type: {:?}", pointer_ty);
return;
}

self.pointers.push((pointer, pointee_ty))
}
}

fn split_block(
basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
location: Location,
) -> BasicBlock {
let block_data = &mut basic_blocks[location.block];

// Drain every statement after this one and move the current terminator to a new basic block
let new_block = BasicBlockData {
statements: block_data.statements.split_off(location.statement_index),
terminator: block_data.terminator.take(),
is_cleanup: block_data.is_cleanup,
};

basic_blocks.push(new_block)
}

fn insert_alignment_check<'tcx>(
tcx: TyCtxt<'tcx>,
local_decls: &mut LocalDecls<'tcx>,
block_data: &mut BasicBlockData<'tcx>,
pointer: Place<'tcx>,
pointee_ty: Ty<'tcx>,
source_info: SourceInfo,
new_block: BasicBlock,
) {
// Cast the pointer to a *const ()
let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not });
let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr);
let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into();
block_data
.statements
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) });

// Transmute the pointer to a usize (equivalent to `ptr.addr()`)
let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize);
let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data
.statements
.push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) });

// Get the alignment of the pointee
let alignment =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty);
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((alignment, rvalue))),
});

// Subtract 1 from the alignment to get the alignment mask
let alignment_mask =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
let one = Operand::Constant(Box::new(Constant {
span: source_info.span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)),
tcx.types.usize,
),
}));
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_mask,
Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))),
))),
});

// BitAnd the alignment mask with the pointer
scottmcm marked this conversation as resolved.
Show resolved Hide resolved
let alignment_bits =
local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into();
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
alignment_bits,
Rvalue::BinaryOp(
BinOp::BitAnd,
Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))),
),
))),
});

// Check if the alignment bits are all zero
let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into();
let zero = Operand::Constant(Box::new(Constant {
span: source_info.span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)),
tcx.types.usize,
),
}));
block_data.statements.push(Statement {
source_info,
kind: StatementKind::Assign(Box::new((
is_ok,
Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))),
))),
});

// Set this block's terminator to our assert, continuing to new_block if we pass
block_data.terminator = Some(Terminator {
source_info,
kind: TerminatorKind::Assert {
cond: Operand::Copy(is_ok),
expected: true,
target: new_block,
msg: AssertKind::MisalignedPointerDereference {
required: Operand::Copy(alignment),
found: Operand::Copy(addr),
},
cleanup: None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discovered this while rebasing my PR #102906. I don't think None here is correct (which translate to UnwindAction::Continue after my PR). This will cause any destructors on the stack to be skipped if the pointer is misaligned.

I feel that this check should be inserted during MIR build, or, this pass needs to be run before drop elaboration.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ideal outcome is to produce an aborting panic here, because this check can add surprising new unwinding sites and make code that was unwind-safe not unwind-safe. Unwinding can be a significantly worse outcome than an unaligned read.

With your PR, is it possible to make this abort?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. I can just change it to UnwindAction::Terminate which will cause an abort.

},
});
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ mod separate_const_switch;
mod shim;
mod ssa;
// This pass is public to allow external drivers to perform MIR cleanup
mod check_alignment;
pub mod simplify;
mod simplify_branches;
mod simplify_comparison_integral;
Expand Down Expand Up @@ -546,6 +547,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
tcx,
body,
&[
&check_alignment::CheckAlignment,
&reveal_all::RevealAll, // has to be done before inlining, since inlined code is in RevealAll mode.
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
&unreachable_prop::UnreachablePropagation,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ symbols! {
panic_implementation,
panic_info,
panic_location,
panic_misaligned_pointer_dereference,
panic_nounwind,
panic_runtime,
panic_str,
Expand Down
14 changes: 14 additions & 0 deletions library/core/src/panicking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,20 @@ fn panic_bounds_check(index: usize, len: usize) -> ! {
panic!("index out of bounds: the len is {len} but the index is {index}")
}

#[cold]
#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never))]
#[track_caller]
#[cfg_attr(not(bootstrap), lang = "panic_misaligned_pointer_dereference")] // needed by codegen for panic on misaligned pointer deref
fn panic_misaligned_pointer_dereference(required: usize, found: usize) -> ! {
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
if cfg!(feature = "panic_immediate_abort") {
super::intrinsics::abort()
}

panic!(
"misaligned pointer dereference: address must be a multiple of {required:#x} but is {found:#x}"
)
}

/// Panic because we cannot unwind out of a function.
///
/// This function is called directly by the codegen backend, and must not have
Expand Down
Loading