diff --git a/compiler/rustc_const_eval/src/interpret/terminator.rs b/compiler/rustc_const_eval/src/interpret/terminator.rs index b2ebcceceb328..33c808649562c 100644 --- a/compiler/rustc_const_eval/src/interpret/terminator.rs +++ b/compiler/rustc_const_eval/src/interpret/terminator.rs @@ -255,6 +255,16 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { caller_abi: &ArgAbi<'tcx, Ty<'tcx>>, callee_abi: &ArgAbi<'tcx, Ty<'tcx>>, ) -> bool { + let primitive_abi_compat = |a1: abi::Primitive, a2: abi::Primitive| -> bool { + match (a1, a2) { + // For integers, ignore the sign. + (abi::Primitive::Int(int_ty1, _sign1), abi::Primitive::Int(int_ty2, _sign2)) => { + int_ty1 == int_ty2 + } + // For everything else we require full equality. + _ => a1 == a2, + } + }; // Heuristic for type comparison. let layout_compat = || { if caller_abi.layout.ty == callee_abi.layout.ty { @@ -267,28 +277,24 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { // then who knows what happens. return false; } - if caller_abi.layout.size != callee_abi.layout.size - || caller_abi.layout.align.abi != callee_abi.layout.align.abi - { - // This cannot go well... - return false; - } - // The rest *should* be okay, but we are extra conservative. + // This is tricky. Some ABIs split aggregates up into multiple registers etc, so we have + // to be super careful here. For the scalar ABIs we conveniently already have all the + // newtypes unwrapped etc, so in those cases we can just compare the scalar components. + // Everything else we just reject for now. match (caller_abi.layout.abi, callee_abi.layout.abi) { - // Different valid ranges are okay (once we enforce validity, - // that will take care to make it UB to leave the range, just - // like for transmute). + // Different valid ranges are okay (the validity check will complain if this leads + // to invalid transmutes). (abi::Abi::Scalar(caller), abi::Abi::Scalar(callee)) => { - caller.primitive() == callee.primitive() + primitive_abi_compat(caller.primitive(), callee.primitive()) } ( abi::Abi::ScalarPair(caller1, caller2), abi::Abi::ScalarPair(callee1, callee2), ) => { - caller1.primitive() == callee1.primitive() - && caller2.primitive() == callee2.primitive() + primitive_abi_compat(caller1.primitive(), callee1.primitive()) + && primitive_abi_compat(caller2.primitive(), callee2.primitive()) } - // Be conservative + // Be conservative. _ => false, } }; @@ -309,7 +315,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { return true; }; let mode_compat = || match (&caller_abi.mode, &callee_abi.mode) { - (PassMode::Ignore, PassMode::Ignore) => true, + (PassMode::Ignore, PassMode::Ignore) => true, // can still be reached for the return type (PassMode::Direct(a1), PassMode::Direct(a2)) => arg_attr_compat(a1, a2), (PassMode::Pair(a1, b1), PassMode::Pair(a2, b2)) => { arg_attr_compat(a1, a2) && arg_attr_compat(b1, b2) @@ -326,7 +332,15 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> { _ => false, }; + // We have to check both. `layout_compat` is needed to reject e.g. `i32` vs `f32`, + // which is not reflected in `PassMode`. `mode_compat` is needed to reject `u8` vs `bool`, + // which have the same `abi::Primitive` but different `arg_ext`. if layout_compat() && mode_compat() { + // Something went very wrong if our checks don't even imply that the layout is the same. + assert!( + caller_abi.layout.size == callee_abi.layout.size + && caller_abi.layout.align.abi == callee_abi.layout.align.abi + ); return true; } trace!( diff --git a/src/tools/miri/tests/pass/function_calls/abi_compat.rs b/src/tools/miri/tests/pass/function_calls/abi_compat.rs new file mode 100644 index 0000000000000..67b87b46bd9bc --- /dev/null +++ b/src/tools/miri/tests/pass/function_calls/abi_compat.rs @@ -0,0 +1,27 @@ +use std::num; +use std::mem; + +fn test_abi_compat(t: T, u: U) { + fn id(x: T) -> T { x } + + // This checks ABI compatibility both for arguments and return values, + // in both directions. + let f: fn(T) -> T = id; + let f: fn(U) -> U = unsafe { std::mem::transmute(f) }; + drop(f(u)); + + let f: fn(U) -> U = id; + let f: fn(T) -> T = unsafe { std::mem::transmute(f) }; + drop(f(t)); +} + +fn main() { + test_abi_compat(0u32, 'x'); + test_abi_compat(&0u32, &([true; 4], [0u32; 0])); + test_abi_compat(0u32, mem::MaybeUninit::new(0u32)); + test_abi_compat(42u32, num::NonZeroU32::new(1).unwrap()); + test_abi_compat(0u32, Some(num::NonZeroU32::new(1).unwrap())); + test_abi_compat(0u32, 0i32); + // Note that `bool` and `u8` are *not* compatible! + // One of them has `arg_ext: Zext`, the other does not. +}