Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable MIR inlining #91743

Merged
merged 12 commits into from
Jul 2, 2022
1 change: 1 addition & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ fn codegen_stmt<'tcx>(
substs,
ty::ClosureKind::FnOnce,
)
.expect("failed to normalize and resolve closure during codegen")
.polymorphize(fx.tcx);
let func_ref = fx.get_function_ref(instance);
let func_addr = fx.bcx.ins().func_addr(fx.pointer_type, func_ref);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
substs,
ty::ClosureKind::FnOnce,
)
.expect("failed to normalize and resolve closure during codegen")
.polymorphize(bx.cx().tcx());
OperandValue::Immediate(bx.cx().get_fn_addr(instance))
}
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
def_id,
substs,
ty::ClosureKind::FnOnce,
);
)
.ok_or_else(|| err_inval!(TooGeneric))?;
let fn_ptr = self.create_fn_alloc_ptr(FnVal::Instance(instance));
self.write_pointer(fn_ptr, dest)?;
}
Expand Down
20 changes: 4 additions & 16 deletions compiler/rustc_const_eval/src/interpret/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,10 @@ where
let is_used = unused_params.contains(index).map_or(true, |unused| !unused);
// Only recurse when generic parameters in fns, closures and generators
// are used and require substitution.
match (is_used, subst.needs_subst()) {
// Just in case there are closures or generators within this subst,
// recurse.
(true, true) => return subst.visit_with(self),
// Confirm that polymorphization replaced the parameter with
// `ty::Param`/`ty::ConstKind::Param`.
(false, true) if cfg!(debug_assertions) => match subst.unpack() {
ty::subst::GenericArgKind::Type(ty) => {
assert!(matches!(ty.kind(), ty::Param(_)))
}
ty::subst::GenericArgKind::Const(ct) => {
assert!(matches!(ct.kind(), ty::ConstKind::Param(_)))
}
ty::subst::GenericArgKind::Lifetime(..) => (),
},
_ => {}
// Just in case there are closures or generators within this subst,
// recurse.
if is_used && subst.needs_subst() {
return subst.visit_with(self);
}
}
ControlFlow::CONTINUE
Expand Down
11 changes: 6 additions & 5 deletions compiler/rustc_middle/src/ty/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,12 @@ impl<'tcx> Instance<'tcx> {
def_id: DefId,
substs: ty::SubstsRef<'tcx>,
requested_kind: ty::ClosureKind,
) -> Instance<'tcx> {
) -> Option<Instance<'tcx>> {
let actual_kind = substs.as_closure().kind();

match needs_fn_once_adapter_shim(actual_kind, requested_kind) {
Ok(true) => Instance::fn_once_adapter_instance(tcx, def_id, substs),
_ => Instance::new(def_id, substs),
_ => Some(Instance::new(def_id, substs)),
}
}

Expand All @@ -515,7 +515,7 @@ impl<'tcx> Instance<'tcx> {
tcx: TyCtxt<'tcx>,
closure_did: DefId,
substs: ty::SubstsRef<'tcx>,
) -> Instance<'tcx> {
) -> Option<Instance<'tcx>> {
debug!("fn_once_adapter_shim({:?}, {:?})", closure_did, substs);
let fn_once = tcx.require_lang_item(LangItem::FnOnce, None);
let call_once = tcx
Expand All @@ -531,12 +531,13 @@ impl<'tcx> Instance<'tcx> {
let self_ty = tcx.mk_closure(closure_did, substs);

let sig = substs.as_closure().sig();
let sig = tcx.normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), sig);
let sig =
tcx.try_normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), sig).ok()?;
assert_eq!(sig.inputs().len(), 1);
let substs = tcx.mk_substs_trait(self_ty, &[sig.inputs()[0].into()]);

debug!("fn_once_adapter_shim: self_ty={:?} sig={:?}", self_ty, sig);
Instance { def, substs }
Some(Instance { def, substs })
}

/// Depending on the kind of `InstanceDef`, the MIR body associated with an
Expand Down
20 changes: 20 additions & 0 deletions compiler/rustc_middle/src/ty/normalize_erasing_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,26 @@ impl<'tcx> TyCtxt<'tcx> {
self.normalize_erasing_regions(param_env, value)
}

/// If you have a `Binder<'tcx, T>`, you can do this to strip out the
/// late-bound regions and then normalize the result, yielding up
/// a `T` (with regions erased). This is appropriate when the
/// binder is being instantiated at the call site.
///
/// N.B., currently, higher-ranked type bounds inhibit
/// normalization. Therefore, each time we erase them in
/// codegen, we need to normalize the contents.
pub fn try_normalize_erasing_late_bound_regions<T>(
self,
param_env: ty::ParamEnv<'tcx>,
value: ty::Binder<'tcx, T>,
) -> Result<T, NormalizationError<'tcx>>
where
T: TypeFoldable<'tcx>,
{
let value = self.erase_late_bound_regions(value);
self.try_normalize_erasing_regions(param_env, value)
}

/// Monomorphizes a type from the AST by first applying the
/// in-scope substitutions and then normalizing any associated
/// types.
Expand Down
100 changes: 68 additions & 32 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
//! Inlining pass for MIR functions
use crate::deref_separator::deref_finder;
use rustc_attr::InlineAttr;
use rustc_const_eval::transform::validate::equal_up_to_regions;
use rustc_index::bit_set::BitSet;
use rustc_index::vec::Idx;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::subst::Subst;
use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
use rustc_session::config::OptLevel;
use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
use rustc_target::spec::abi::Abi;

Expand Down Expand Up @@ -43,7 +44,15 @@ impl<'tcx> MirPass<'tcx> for Inline {
return enabled;
}

sess.opts.mir_opt_level() >= 3
match sess.mir_opt_level() {
0 | 1 => false,
2 => {
(sess.opts.optimize == OptLevel::Default
|| sess.opts.optimize == OptLevel::Aggressive)
&& sess.opts.incremental == None
}
_ => true,
}
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
Expand Down Expand Up @@ -76,13 +85,6 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
}

let param_env = tcx.param_env_reveal_all_normalized(def_id);
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id);
let param_env = rustc_trait_selection::traits::normalize_param_env_or_error(
tcx,
def_id.to_def_id(),
param_env,
ObligationCause::misc(body.span, hir_id),
);

let mut this = Inliner {
tcx,
Expand Down Expand Up @@ -166,6 +168,45 @@ impl<'tcx> Inliner<'tcx> {
return Err("failed to normalize callee body");
};

// Check call signature compatibility.
// Normally, this shouldn't be required, but trait normalization failure can create a
// validation ICE.
let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
let output_type = callee_body.return_ty();
if !equal_up_to_regions(self.tcx, self.param_env, output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("failed to normalize return type");
}
if callsite.fn_sig.abi() == Abi::RustCall {
let mut args = args.into_iter();
let _ = args.next(); // Skip `self` argument.
let arg_tuple_ty = args.next().unwrap().ty(&caller_body.local_decls, self.tcx);
assert!(args.next().is_none());

let ty::Tuple(arg_tuple_tys) = arg_tuple_ty.kind() else {
bug!("Closure arguments are not passed as a tuple");
};

for (arg_ty, input) in arg_tuple_tys.iter().zip(callee_body.args_iter().skip(1)) {
let input_type = callee_body.local_decls[input].ty;
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize tuple argument type");
}
}
} else {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = callee_body.local_decls[input].ty;
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
if !equal_up_to_regions(self.tcx, self.param_env, arg_ty, input_type) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize argument type");
}
}
}

let old_blocks = caller_body.basic_blocks().next_index();
self.inline_call(caller_body, &callsite, callee_body);
let new_blocks = old_blocks..caller_body.basic_blocks().next_index();
Expand Down Expand Up @@ -263,6 +304,10 @@ impl<'tcx> Inliner<'tcx> {
return None;
}

if self.history.contains(&callee) {
return None;
}

let fn_sig = self.tcx.bound_fn_sig(def_id).subst(self.tcx, substs);

return Some(CallSite {
Expand All @@ -285,8 +330,14 @@ impl<'tcx> Inliner<'tcx> {
callsite: &CallSite<'tcx>,
callee_attrs: &CodegenFnAttrs,
) -> Result<(), &'static str> {
if let InlineAttr::Never = callee_attrs.inline {
return Err("never inline hint");
match callee_attrs.inline {
InlineAttr::Never => return Err("never inline hint"),
InlineAttr::Always | InlineAttr::Hint => {}
InlineAttr::None => {
if self.tcx.sess.mir_opt_level() <= 2 {
return Err("at mir-opt-level=2, only #[inline] is inlined");
}
}
}

// Only inline local functions if they would be eligible for cross-crate
Expand Down Expand Up @@ -407,22 +458,9 @@ impl<'tcx> Inliner<'tcx> {
}

TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
if let ty::FnDef(def_id, substs) =
if let ty::FnDef(def_id, _) =
*callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
{
if let Ok(substs) =
self.tcx.try_normalize_erasing_regions(self.param_env, substs)
{
if let Ok(Some(instance)) =
Instance::resolve(self.tcx, self.param_env, def_id, substs)
{
if callsite.callee.def_id() == instance.def_id() {
return Err("self-recursion");
} else if self.history.contains(&instance) {
return Err("already inlined");
}
}
}
// Don't give intrinsics the extra penalty for calls
if tcx.is_intrinsic(def_id) {
cost += INSTR_COST;
Expand Down Expand Up @@ -482,14 +520,12 @@ impl<'tcx> Inliner<'tcx> {
if let InlineAttr::Always = callee_attrs.inline {
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
Ok(())
} else if cost <= threshold {
debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold);
Ok(())
} else {
if cost <= threshold {
debug!("INLINING {:?} [cost={} <= threshold={}]", callsite, cost, threshold);
Ok(())
} else {
debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold);
Err("cost above threshold")
}
debug!("NOT inlining {:?} [cost={} > threshold={}]", callsite, cost, threshold);
Err("cost above threshold")
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/inline/cycle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
trace!(?caller, ?param_env, ?substs, "cannot normalize, skipping");
continue;
};
let Some(callee) = ty::Instance::resolve(tcx, param_env, callee, substs).unwrap() else {
let Ok(Some(callee)) = ty::Instance::resolve(tcx, param_env, callee, substs) else {
trace!(?callee, "cannot resolve, skipping");
continue;
};
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_monomorphize/src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,8 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirNeighborCollector<'a, 'tcx> {
def_id,
substs,
ty::ClosureKind::FnOnce,
);
)
.expect("failed to normalize and resolve closure during codegen");
if should_codegen_locally(self.tcx, &instance) {
self.output.push(create_fn_mono_item(self.tcx, instance, span));
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ty_utils/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,12 +332,12 @@ fn resolve_associated_item<'tcx>(
}),
traits::ImplSource::Closure(closure_data) => {
let trait_closure_kind = tcx.fn_trait_kind_from_lang_item(trait_id).unwrap();
Some(Instance::resolve_closure(
Instance::resolve_closure(
tcx,
closure_data.closure_def_id,
closure_data.substs,
trait_closure_kind,
))
)
}
traits::ImplSource::FnPointer(ref data) => match data.fn_ty.kind() {
ty::FnDef(..) | ty::FnPtr(..) => Some(Instance {
Expand Down
4 changes: 2 additions & 2 deletions src/test/codegen/issue-37945.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub fn is_empty_1(xs: Iter<f32>) -> bool {
// CHECK-NEXT: start:
// CHECK-NEXT: [[A:%.*]] = icmp ne {{i32\*|ptr}} %xs.1, null
// CHECK-NEXT: tail call void @llvm.assume(i1 [[A]])
// CHECK-NEXT: [[B:%.*]] = icmp eq {{i32\*|ptr}} %xs.0, %xs.1
// CHECK-NEXT: [[B:%.*]] = icmp eq {{i32\*|ptr}} %xs.1, %xs.0
// CHECK-NEXT: ret i1 [[B:%.*]]
{xs}.next().is_none()
}
Expand All @@ -28,7 +28,7 @@ pub fn is_empty_2(xs: Iter<f32>) -> bool {
// CHECK-NEXT: start:
// CHECK-NEXT: [[C:%.*]] = icmp ne {{i32\*|ptr}} %xs.1, null
// CHECK-NEXT: tail call void @llvm.assume(i1 [[C]])
// CHECK-NEXT: [[D:%.*]] = icmp eq {{i32\*|ptr}} %xs.0, %xs.1
// CHECK-NEXT: [[D:%.*]] = icmp eq {{i32\*|ptr}} %xs.1, %xs.0
// CHECK-NEXT: ret i1 [[D:%.*]]
xs.map(|&x| x).next().is_none()
}
2 changes: 1 addition & 1 deletion src/test/codegen/issue-75659.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// This test checks that the call to memchr/slice_contains is optimized away
// when searching in small slices.

// compile-flags: -O
// compile-flags: -O -Zinline-mir=no
// only-x86_64

#![crate_type = "lib"]
Expand Down
12 changes: 5 additions & 7 deletions src/test/codegen/mem-replace-direct-memcpy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// may e.g. multiply `size_of::<T>()` with a variable "count" (which is only
// known to be `1` after inlining).

// compile-flags: -C no-prepopulate-passes
// compile-flags: -C no-prepopulate-passes -Zinline-mir=no

#![crate_type = "lib"]

Expand All @@ -12,14 +12,12 @@ pub fn replace_byte(dst: &mut u8, src: u8) -> u8 {
}

// NOTE(eddyb) the `CHECK-NOT`s ensure that the only calls of `@llvm.memcpy` in
// the entire output, are the two direct calls we want, from `ptr::{read,write}`.
// the entire output, are the two direct calls we want, from `ptr::replace`.

// CHECK-NOT: call void @llvm.memcpy
// CHECK: ; core::ptr::read
// CHECK: ; core::mem::replace
// CHECK-NOT: call void @llvm.memcpy
// CHECK: call void @llvm.memcpy.{{.+}}({{i8\*|ptr}} align 1 %{{.*}}, {{i8\*|ptr}} align 1 %src, i{{.*}} 1, i1 false)
// CHECK: call void @llvm.memcpy.{{.+}}({{i8\*|ptr}} align 1 %{{.*}}, {{i8\*|ptr}} align 1 %dest, i{{.*}} 1, i1 false)
Copy link
Member

@RalfJung RalfJung Jul 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change made the test fail in --stage 1 (the default for ./x.py test) on my system

/home/r/src/rust/rustc.3/build/x86_64-unknown-linux-gnu/test/codegen/mem-replace-direct-memcpy/mem-replace-direct-memcpy.ll:138:2: note: possible intended match here
 call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %_4, i8* align 1 %src, i64 1, i1 false)
 ^

It should be src, not dest, it seems. Or it should be %{{.*}} I guess since the name seems to be unstable. The 2nd argument is the src though so the "dest" you added here makes no sense to me.

// CHECK-NOT: call void @llvm.memcpy
// CHECK: ; core::ptr::write
// CHECK-NOT: call void @llvm.memcpy
// CHECK: call void @llvm.memcpy.{{.+}}({{i8\*|ptr}} align 1 %dst, {{i8\*|ptr}} align 1 %src, i{{.*}} 1, i1 false)
// CHECK: call void @llvm.memcpy.{{.+}}({{i8\*|ptr}} align 1 %dest, {{i8\*|ptr}} align 1 %src{{.*}}, i{{.*}} 1, i1 false)
// CHECK-NOT: call void @llvm.memcpy
2 changes: 1 addition & 1 deletion src/test/codegen/remap_path_prefix/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// ignore-windows
//

// compile-flags: -g -C no-prepopulate-passes --remap-path-prefix={{cwd}}=/the/cwd --remap-path-prefix={{src-base}}=/the/src
// compile-flags: -g -C no-prepopulate-passes --remap-path-prefix={{cwd}}=/the/cwd --remap-path-prefix={{src-base}}=/the/src -Zinline-mir=no
// aux-build:remap_path_prefix_aux.rs

extern crate remap_path_prefix_aux;
Expand Down
7 changes: 4 additions & 3 deletions src/test/codegen/simd-wide-sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ pub fn wider_reduce_iter(x: Simd<u8, N>) -> u16 {
#[no_mangle]
// CHECK-LABEL: @wider_reduce_into_iter
pub fn wider_reduce_into_iter(x: Simd<u8, N>) -> u16 {
// CHECK: zext <8 x i8>
// CHECK-SAME: to <8 x i16>
// CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
// FIXME MIR inlining messes up LLVM optimizations.
// WOULD-CHECK: zext <8 x i8>
// WOULD-CHECK-SAME: to <8 x i16>
// WOULD-CHECK: call i16 @llvm.vector.reduce.add.v8i16(<8 x i16>
x.to_array().into_iter().map(u16::from).sum()
}
Loading