Skip to content

Commit

Permalink
Auto merge of #98332 - oli-obk:assume, r=wesleywiser
Browse files Browse the repository at this point in the history
Lower the assume intrinsic to a MIR statement

This makes #96862 (comment) easier and will generally allow us to cheaply insert assume intrinsic calls in mir building.

r? rust-lang/wg-mir-opt
  • Loading branch information
bors committed Sep 7, 2022
2 parents 0568b0a + a0130e6 commit e7c7aa7
Show file tree
Hide file tree
Showing 37 changed files with 338 additions and 570 deletions.
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ impl<'tcx> rustc_mir_dataflow::GenKillAnalysis<'tcx> for Borrows<'_, 'tcx> {
| mir::StatementKind::Retag { .. }
| mir::StatementKind::AscribeUserType(..)
| mir::StatementKind::Coverage(..)
| mir::StatementKind::CopyNonOverlapping(..)
| mir::StatementKind::Intrinsic(..)
| mir::StatementKind::Nop => {}
}
}
Expand Down
26 changes: 15 additions & 11 deletions compiler/rustc_borrowck/src/invalidation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use rustc_data_structures::graph::dominators::Dominators;
use rustc_middle::mir::visit::Visitor;
use rustc_middle::mir::{BasicBlock, Body, Location, Place, Rvalue};
use rustc_middle::mir::{self, BasicBlock, Body, Location, NonDivergingIntrinsic, Place, Rvalue};
use rustc_middle::mir::{BorrowKind, Mutability, Operand};
use rustc_middle::mir::{InlineAsmOperand, Terminator, TerminatorKind};
use rustc_middle::mir::{Statement, StatementKind};
Expand Down Expand Up @@ -63,23 +63,24 @@ impl<'cx, 'tcx> Visitor<'tcx> for InvalidationGenerator<'cx, 'tcx> {
StatementKind::FakeRead(box (_, _)) => {
// Only relevant for initialized/liveness/safety checks.
}
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => {
self.consume_operand(location, op);
}
StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
ref src,
ref dst,
ref count,
}) => {
})) => {
self.consume_operand(location, src);
self.consume_operand(location, dst);
self.consume_operand(location, count);
}
StatementKind::Nop
// Only relevant for mir typeck
StatementKind::AscribeUserType(..)
// Doesn't have any language semantics
| StatementKind::Coverage(..)
| StatementKind::AscribeUserType(..)
| StatementKind::Retag { .. }
| StatementKind::StorageLive(..) => {
// `Nop`, `AscribeUserType`, `Retag`, and `StorageLive` are irrelevant
// to borrow check.
}
// Does not actually affect borrowck
| StatementKind::StorageLive(..) => {}
StatementKind::StorageDead(local) => {
self.access_place(
location,
Expand All @@ -88,7 +89,10 @@ impl<'cx, 'tcx> Visitor<'tcx> for InvalidationGenerator<'cx, 'tcx> {
LocalMutationIsAllowed::Yes,
);
}
StatementKind::Deinit(..) | StatementKind::SetDiscriminant { .. } => {
StatementKind::Nop
| StatementKind::Retag { .. }
| StatementKind::Deinit(..)
| StatementKind::SetDiscriminant { .. } => {
bug!("Statement not allowed in this MIR phase")
}
}
Expand Down
28 changes: 14 additions & 14 deletions compiler/rustc_borrowck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use rustc_index::bit_set::ChunkedBitSet;
use rustc_index::vec::IndexVec;
use rustc_infer::infer::{DefiningAnchor, InferCtxt, TyCtxtInferExt};
use rustc_middle::mir::{
traversal, Body, ClearCrossCrate, Local, Location, Mutability, Operand, Place, PlaceElem,
PlaceRef, VarDebugInfoContents,
traversal, Body, ClearCrossCrate, Local, Location, Mutability, NonDivergingIntrinsic, Operand,
Place, PlaceElem, PlaceRef, VarDebugInfoContents,
};
use rustc_middle::mir::{AggregateKind, BasicBlock, BorrowCheckResult, BorrowKind};
use rustc_middle::mir::{Field, ProjectionElem, Promoted, Rvalue, Statement, StatementKind};
Expand Down Expand Up @@ -591,22 +591,19 @@ impl<'cx, 'tcx> rustc_mir_dataflow::ResultsVisitor<'cx, 'tcx> for MirBorrowckCtx
flow_state,
);
}
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
..
}) => {
span_bug!(
StatementKind::Intrinsic(box ref kind) => match kind {
NonDivergingIntrinsic::Assume(op) => self.consume_operand(location, (op, span), flow_state),
NonDivergingIntrinsic::CopyNonOverlapping(..) => span_bug!(
span,
"Unexpected CopyNonOverlapping, should only appear after lower_intrinsics",
)
}
StatementKind::Nop
// Only relevant for mir typeck
StatementKind::AscribeUserType(..)
// Doesn't have any language semantics
| StatementKind::Coverage(..)
| StatementKind::AscribeUserType(..)
| StatementKind::Retag { .. }
| StatementKind::StorageLive(..) => {
// `Nop`, `AscribeUserType`, `Retag`, and `StorageLive` are irrelevant
// to borrow check.
}
// Does not actually affect borrowck
| StatementKind::StorageLive(..) => {}
StatementKind::StorageDead(local) => {
self.access_place(
location,
Expand All @@ -616,7 +613,10 @@ impl<'cx, 'tcx> rustc_mir_dataflow::ResultsVisitor<'cx, 'tcx> for MirBorrowckCtx
flow_state,
);
}
StatementKind::Deinit(..) | StatementKind::SetDiscriminant { .. } => {
StatementKind::Nop
| StatementKind::Retag { .. }
| StatementKind::Deinit(..)
| StatementKind::SetDiscriminant { .. } => {
bug!("Statement not allowed in this MIR phase")
}
}
Expand Down
13 changes: 7 additions & 6 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1302,12 +1302,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
);
}
}
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
..
}) => span_bug!(
stmt.source_info.span,
"Unexpected StatementKind::CopyNonOverlapping, should only appear after lowering_intrinsics",
),
StatementKind::Intrinsic(box ref kind) => match kind {
NonDivergingIntrinsic::Assume(op) => self.check_operand(op, location),
NonDivergingIntrinsic::CopyNonOverlapping(..) => span_bug!(
stmt.source_info.span,
"Unexpected NonDivergingIntrinsic::CopyNonOverlapping, should only appear after lowering_intrinsics",
),
},
StatementKind::FakeRead(..)
| StatementKind::StorageLive(..)
| StatementKind::StorageDead(..)
Expand Down
39 changes: 25 additions & 14 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,20 +794,31 @@ fn codegen_stmt<'tcx>(
| StatementKind::AscribeUserType(..) => {}

StatementKind::Coverage { .. } => fx.tcx.sess.fatal("-Zcoverage is unimplemented"),
StatementKind::CopyNonOverlapping(inner) => {
let dst = codegen_operand(fx, &inner.dst);
let pointee = dst
.layout()
.pointee_info_at(fx, rustc_target::abi::Size::ZERO)
.expect("Expected pointer");
let dst = dst.load_scalar(fx);
let src = codegen_operand(fx, &inner.src).load_scalar(fx);
let count = codegen_operand(fx, &inner.count).load_scalar(fx);
let elem_size: u64 = pointee.size.bytes();
let bytes =
if elem_size != 1 { fx.bcx.ins().imul_imm(count, elem_size as i64) } else { count };
fx.bcx.call_memcpy(fx.target_config, dst, src, bytes);
}
StatementKind::Intrinsic(ref intrinsic) => match &**intrinsic {
// We ignore `assume` intrinsics, they are only useful for optimizations
NonDivergingIntrinsic::Assume(_) => {}
NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
src,
dst,
count,
}) => {
let dst = codegen_operand(fx, dst);
let pointee = dst
.layout()
.pointee_info_at(fx, rustc_target::abi::Size::ZERO)
.expect("Expected pointer");
let dst = dst.load_scalar(fx);
let src = codegen_operand(fx, src).load_scalar(fx);
let count = codegen_operand(fx, count).load_scalar(fx);
let elem_size: u64 = pointee.size.bytes();
let bytes = if elem_size != 1 {
fx.bcx.ins().imul_imm(count, elem_size as i64)
} else {
count
};
fx.bcx.call_memcpy(fx.target_config, dst, src, bytes);
}
},
}
}

Expand Down
8 changes: 5 additions & 3 deletions compiler/rustc_codegen_cranelift/src/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,9 +536,11 @@ pub(crate) fn mir_operand_get_const_val<'tcx>(
{
return None;
}
StatementKind::CopyNonOverlapping(_) => {
return None;
} // conservative handling
StatementKind::Intrinsic(ref intrinsic) => match **intrinsic {
NonDivergingIntrinsic::CopyNonOverlapping(..) => return None,
NonDivergingIntrinsic::Assume(..) => {}
},
// conservative handling
StatementKind::Assign(_)
| StatementKind::FakeRead(_)
| StatementKind::SetDiscriminant { .. }
Expand Down
3 changes: 0 additions & 3 deletions compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,6 @@ fn codegen_regular_intrinsic_call<'tcx>(
let usize_layout = fx.layout_of(fx.tcx.types.usize);

match intrinsic {
sym::assume => {
intrinsic_args!(fx, args => (_a); intrinsic);
}
sym::likely | sym::unlikely => {
intrinsic_args!(fx, args => (a); intrinsic);

Expand Down
4 changes: 0 additions & 4 deletions compiler/rustc_codegen_ssa/src/mir/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let result = PlaceRef::new_sized(llresult, fn_abi.ret.layout);

let llval = match name {
sym::assume => {
bx.assume(args[0].immediate());
return;
}
sym::abort => {
bx.abort();
return;
Expand Down
14 changes: 9 additions & 5 deletions compiler/rustc_codegen_ssa/src/mir/statement.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rustc_middle::mir;
use rustc_middle::mir::NonDivergingIntrinsic;

use super::FunctionCx;
use super::LocalRef;
Expand Down Expand Up @@ -73,11 +74,14 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
self.codegen_coverage(&mut bx, coverage.clone(), statement.source_info.scope);
bx
}
mir::StatementKind::CopyNonOverlapping(box mir::CopyNonOverlapping {
ref src,
ref dst,
ref count,
}) => {
mir::StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(ref op)) => {
let op_val = self.codegen_operand(&mut bx, op);
bx.assume(op_val.immediate());
bx
}
mir::StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(
mir::CopyNonOverlapping { ref count, ref src, ref dst },
)) => {
let dst_val = self.codegen_operand(&mut bx, dst);
let src_val = self.codegen_operand(&mut bx, src);
let count = self.codegen_operand(&mut bx, count).immediate();
Expand Down
34 changes: 27 additions & 7 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rustc_hir::def_id::DefId;
use rustc_middle::mir::{
self,
interpret::{ConstValue, GlobalId, InterpResult, PointerArithmetic, Scalar},
BinOp,
BinOp, NonDivergingIntrinsic,
};
use rustc_middle::ty;
use rustc_middle::ty::layout::LayoutOf as _;
Expand Down Expand Up @@ -506,12 +506,6 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
// These just return their argument
self.copy_op(&args[0], dest, /*allow_transmute*/ false)?;
}
sym::assume => {
let cond = self.read_scalar(&args[0])?.to_bool()?;
if !cond {
throw_ub_format!("`assume` intrinsic called with `false`");
}
}
sym::raw_eq => {
let result = self.raw_eq_intrinsic(&args[0], &args[1])?;
self.write_scalar(result, dest)?;
Expand All @@ -536,6 +530,32 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
Ok(true)
}

pub(super) fn emulate_nondiverging_intrinsic(
&mut self,
intrinsic: &NonDivergingIntrinsic<'tcx>,
) -> InterpResult<'tcx> {
match intrinsic {
NonDivergingIntrinsic::Assume(op) => {
let op = self.eval_operand(op, None)?;
let cond = self.read_scalar(&op)?.to_bool()?;
if !cond {
throw_ub_format!("`assume` called with `false`");
}
Ok(())
}
NonDivergingIntrinsic::CopyNonOverlapping(mir::CopyNonOverlapping {
count,
src,
dst,
}) => {
let src = self.eval_operand(src, None)?;
let dst = self.eval_operand(dst, None)?;
let count = self.eval_operand(count, None)?;
self.copy_intrinsic(&src, &dst, &count, /* nonoverlapping */ true)
}
}
}

pub fn exact_div(
&mut self,
a: &ImmTy<'tcx, M::Provenance>,
Expand Down
8 changes: 1 addition & 7 deletions compiler/rustc_const_eval/src/interpret/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
M::retag(self, *kind, &dest)?;
}

// Call CopyNonOverlapping
CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping { src, dst, count }) => {
let src = self.eval_operand(src, None)?;
let dst = self.eval_operand(dst, None)?;
let count = self.eval_operand(count, None)?;
self.copy_intrinsic(&src, &dst, &count, /* nonoverlapping */ true)?;
}
Intrinsic(box ref intrinsic) => self.emulate_nondiverging_intrinsic(intrinsic)?,

// Statements we do not track.
AscribeUserType(..) => {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
| StatementKind::Retag { .. }
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
| StatementKind::CopyNonOverlapping(..)
| StatementKind::Intrinsic(..)
| StatementKind::Nop => {}
}
}
Expand Down
24 changes: 16 additions & 8 deletions compiler/rustc_const_eval/src/transform/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use rustc_middle::mir::interpret::Scalar;
use rustc_middle::mir::visit::NonUseContext::VarDebugInfo;
use rustc_middle::mir::visit::{PlaceContext, Visitor};
use rustc_middle::mir::{
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, Local, Location,
MirPass, MirPhase, Operand, Place, PlaceElem, PlaceRef, ProjectionElem, RuntimePhase, Rvalue,
SourceScope, Statement, StatementKind, Terminator, TerminatorKind, UnOp, START_BLOCK,
traversal, AggregateKind, BasicBlock, BinOp, Body, BorrowKind, CastKind, CopyNonOverlapping,
Local, Location, MirPass, MirPhase, NonDivergingIntrinsic, Operand, Place, PlaceElem, PlaceRef,
ProjectionElem, RuntimePhase, Rvalue, SourceScope, Statement, StatementKind, Terminator,
TerminatorKind, UnOp, START_BLOCK,
};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::subst::Subst;
Expand Down Expand Up @@ -636,11 +637,18 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> {
);
}
}
StatementKind::CopyNonOverlapping(box rustc_middle::mir::CopyNonOverlapping {
ref src,
ref dst,
ref count,
}) => {
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(op)) => {
let ty = op.ty(&self.body.local_decls, self.tcx);
if !ty.is_bool() {
self.fail(
location,
format!("`assume` argument must be `bool`, but got: `{}`", ty),
);
}
}
StatementKind::Intrinsic(box NonDivergingIntrinsic::CopyNonOverlapping(
CopyNonOverlapping { src, dst, count },
)) => {
let src_ty = src.ty(&self.body.local_decls, self.tcx);
let op_src_ty = if let Some(src_deref) = src_ty.builtin_deref(true) {
src_deref.ty
Expand Down
8 changes: 1 addition & 7 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1362,13 +1362,7 @@ impl Debug for Statement<'_> {
write!(fmt, "Coverage::{:?} for {:?}", kind, rgn)
}
Coverage(box ref coverage) => write!(fmt, "Coverage::{:?}", coverage.kind),
CopyNonOverlapping(box crate::mir::CopyNonOverlapping {
ref src,
ref dst,
ref count,
}) => {
write!(fmt, "copy_nonoverlapping(src={:?}, dst={:?}, count={:?})", src, dst, count)
}
Intrinsic(box ref intrinsic) => write!(fmt, "{intrinsic}"),
Nop => write!(fmt, "nop"),
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/mir/spanview.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ pub fn statement_kind_name(statement: &Statement<'_>) -> &'static str {
Retag(..) => "Retag",
AscribeUserType(..) => "AscribeUserType",
Coverage(..) => "Coverage",
CopyNonOverlapping(..) => "CopyNonOverlapping",
Intrinsic(..) => "Intrinsic",
Nop => "Nop",
}
}
Expand Down
Loading

0 comments on commit e7c7aa7

Please sign in to comment.