Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollup of 4 pull requests #127965

Closed
wants to merge 10 commits into from
41 changes: 36 additions & 5 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,46 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
(ty::Dynamic(data_a, _, ty::Dyn), ty::Dynamic(data_b, _, ty::Dyn)) => {
let val = self.read_immediate(src)?;
if data_a.principal() == data_b.principal() {
// A NOP cast that doesn't actually change anything, should be allowed even with mismatching vtables.
// (But currently mismatching vtables violate the validity invariant so UB is triggered anyway.)
return self.write_immediate(*val, dest);
}
// Take apart the old pointer, and find the dynamic type.
let (old_data, old_vptr) = val.to_scalar_pair();
let old_data = old_data.to_pointer(self)?;
let old_vptr = old_vptr.to_pointer(self)?;
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;

// Sanity-check that `supertrait_vtable_slot` in this type's vtable indeed produces
// our destination trait.
if cfg!(debug_assertions) {
let vptr_entry_idx =
self.tcx.supertrait_vtable_slot((src_pointee_ty, dest_pointee_ty));
let vtable_entries = self.vtable_entries(data_a.principal(), ty);
if let Some(entry_idx) = vptr_entry_idx {
let Some(&ty::VtblEntry::TraitVPtr(upcast_trait_ref)) =
vtable_entries.get(entry_idx)
else {
span_bug!(
self.cur_span(),
"invalid vtable entry index in {} -> {} upcast",
src_pointee_ty,
dest_pointee_ty
);
};
let erased_trait_ref = upcast_trait_ref
.map_bound(|r| ty::ExistentialTraitRef::erase_self_ty(*self.tcx, r));
assert!(
data_b
.principal()
.is_some_and(|b| self.eq_in_param_env(erased_trait_ref, b))
);
} else {
// In this case codegen would keep using the old vtable. We don't want to do
// that as it has the wrong trait. The reason codegen can do this is that
// one vtable is a prefix of the other, so we double-check that.
let vtable_entries_b = self.vtable_entries(data_b.principal(), ty);
assert!(&vtable_entries[..vtable_entries_b.len()] == vtable_entries_b);
};
}

// Get the destination trait vtable and return that.
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
}
Expand Down
30 changes: 30 additions & 0 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ use std::cell::Cell;
use std::{fmt, mem};

use either::{Either, Left, Right};
use rustc_infer::infer::at::ToTrace;
use rustc_infer::traits::ObligationCause;
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::{debug, info, info_span, instrument, trace};

use rustc_errors::DiagCtxtHandle;
use rustc_hir::{self as hir, def_id::DefId, definitions::DefPathData};
use rustc_index::IndexVec;
use rustc_infer::infer::TyCtxtInferExt;
use rustc_middle::mir;
use rustc_middle::mir::interpret::{
CtfeProvenance, ErrorHandled, InvalidMetaKind, ReportedErrorInfo,
Expand Down Expand Up @@ -640,6 +644,32 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}

/// Check if the two things are equal in the current param_env, using an infctx to get proper
/// equality checks.
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
where
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
{
// Fast path: compare directly.
if a == b {
return true;
}
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let a = ocx.normalize(&cause, self.param_env, a);
let b = ocx.normalize(&cause, self.param_env, b);
if ocx.eq(&cause, self.param_env, a, b).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return true;
}
}
return false;
}

/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
/// and is primarily intended for the panic machinery.
Expand Down
17 changes: 6 additions & 11 deletions compiler/rustc_const_eval/src/interpret/terminator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::borrow::Cow;

use either::Either;
use rustc_middle::ty::TyCtxt;
use tracing::trace;

use rustc_middle::{
Expand Down Expand Up @@ -867,7 +866,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
};

// Obtain the underlying trait we are working on, and the adjusted receiver argument.
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
let (trait_, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
receiver_place.layout.ty.kind()
{
let recv = self.unpack_dyn_star(&receiver_place, data)?;
Expand Down Expand Up @@ -898,20 +897,16 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
};

// Now determine the actual method to call. We can do that in two different ways and
// compare them to ensure everything fits.
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
};
// Now determine the actual method to call. Usually we use the easy way of just
// looking up the method at index `idx`.
let vtable_entries = self.vtable_entries(trait_, dyn_ty);
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
};
trace!("Virtual call dispatches to {fn_inst:#?}");
// We can also do the lookup based on `def_id` and `dyn_ty`, and check that that
// produces the same result.
if cfg!(debug_assertions) {
let tcx = *self.tcx;

Expand Down
48 changes: 23 additions & 25 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty};
use rustc_middle::ty::{self, Ty, TyCtxt, VtblEntry};
use rustc_target::abi::{Align, Size};
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::trace;

use super::util::ensure_monomorphic_enough;
Expand Down Expand Up @@ -47,35 +44,36 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok((layout.size, layout.align.abi))
}

pub(super) fn vtable_entries(
&self,
trait_: Option<ty::PolyExistentialTraitRef<'tcx>>,
dyn_ty: Ty<'tcx>,
) -> &'tcx [VtblEntry<'tcx>] {
if let Some(trait_) = trait_ {
let trait_ref = trait_.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
}
}

/// Check that the given vtable trait is valid for a pointer/reference/place with the given
/// expected trait type.
pub(super) fn check_vtable_for_type(
&self,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx> {
// Fast path: if they are equal, it's all fine.
if expected_trait.principal() == vtable_trait {
return Ok(());
}
if let (Some(expected_trait), Some(vtable_trait)) =
(expected_trait.principal(), vtable_trait)
{
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return Ok(());
}
}
let eq = match (expected_trait.principal(), vtable_trait) {
(Some(a), Some(b)) => self.eq_in_param_env(a, b),
(None, None) => true,
_ => false,
};
if !eq {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
Ok(())
}

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
Expand Down
20 changes: 12 additions & 8 deletions compiler/rustc_trait_selection/src/traits/vtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ pub(crate) fn first_method_vtable_slot<'tcx>(tcx: TyCtxt<'tcx>, key: ty::TraitRe
}

/// Given a `dyn Subtrait` and `dyn Supertrait` trait object, find the slot of
/// // the trait vptr in the subtrait's vtable.
/// the trait vptr in the subtrait's vtable.
///
/// A return value of `None` means that the original vtable can be reused.
pub(crate) fn supertrait_vtable_slot<'tcx>(
tcx: TyCtxt<'tcx>,
key: (
Expand All @@ -373,20 +375,22 @@ pub(crate) fn supertrait_vtable_slot<'tcx>(
),
) -> Option<usize> {
debug_assert!(!key.has_non_region_infer() && !key.has_non_region_param());

let (source, target) = key;
let ty::Dynamic(source, _, _) = *source.kind() else {

// If the target principal is `None`, we can just return `None`.
let ty::Dynamic(target, _, _) = *target.kind() else {
bug!();
};
let source_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), source.principal().unwrap())
let target_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), target.principal()?)
.with_self_ty(tcx, tcx.types.trait_object_dummy_self);

let ty::Dynamic(target, _, _) = *target.kind() else {
// Given that we have a target principal, it is a bug for there not to be a source principal.
let ty::Dynamic(source, _, _) = *source.kind() else {
bug!();
};
let target_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), target.principal().unwrap())
let source_principal = tcx
.normalize_erasing_regions(ty::ParamEnv::reveal_all(), source.principal().unwrap())
.with_self_ty(tcx, tcx.types.trait_object_dummy_self);

let vtable_segment_callback = {
Expand Down
1 change: 0 additions & 1 deletion library/core/src/num/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#![stable(feature = "rust1", since = "1.0.0")]

use crate::ascii;
use crate::hint;
use crate::intrinsics;
use crate::mem;
use crate::str::FromStr;
Expand Down
60 changes: 54 additions & 6 deletions library/core/src/num/nonzero.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use crate::cmp::Ordering;
use crate::fmt;
use crate::hash::{Hash, Hasher};
use crate::hint;
use crate::intrinsics;
use crate::marker::{Freeze, StructuralPartialEq};
use crate::ops::{BitOr, BitOrAssign, Div, DivAssign, Neg, Rem, RemAssign};
Expand Down Expand Up @@ -604,7 +605,6 @@ macro_rules! nonzero_integer {
}

nonzero_integer_signedness_dependent_methods! {
Self = $Ty,
Primitive = $signedness $Int,
UnsignedPrimitive = $Uint,
}
Expand Down Expand Up @@ -823,7 +823,7 @@ macro_rules! nonzero_integer {
}
}

nonzero_integer_signedness_dependent_impls!($Ty $signedness $Int);
nonzero_integer_signedness_dependent_impls!($signedness $Int);
};

(Self = $Ty:ident, Primitive = unsigned $Int:ident $(,)?) => {
Expand All @@ -849,7 +849,7 @@ macro_rules! nonzero_integer {

macro_rules! nonzero_integer_signedness_dependent_impls {
// Impls for unsigned nonzero types only.
($Ty:ident unsigned $Int:ty) => {
(unsigned $Int:ty) => {
#[stable(feature = "nonzero_div", since = "1.51.0")]
impl Div<NonZero<$Int>> for $Int {
type Output = $Int;
Expand Down Expand Up @@ -897,7 +897,7 @@ macro_rules! nonzero_integer_signedness_dependent_impls {
}
};
// Impls for signed nonzero types only.
($Ty:ident signed $Int:ty) => {
(signed $Int:ty) => {
#[stable(feature = "signed_nonzero_neg", since = "1.71.0")]
impl Neg for NonZero<$Int> {
type Output = Self;
Expand All @@ -918,7 +918,6 @@ macro_rules! nonzero_integer_signedness_dependent_impls {
macro_rules! nonzero_integer_signedness_dependent_methods {
// Associated items for unsigned nonzero types only.
(
Self = $Ty:ident,
Primitive = unsigned $Int:ident,
UnsignedPrimitive = $Uint:ty,
) => {
Expand Down Expand Up @@ -1224,11 +1223,60 @@ macro_rules! nonzero_integer_signedness_dependent_methods {

intrinsics::ctpop(self.get()) < 2
}

/// Returns the square root of the number, rounded down.
///
/// # Examples
///
/// Basic usage:
/// ```
/// #![feature(isqrt)]
/// # use std::num::NonZero;
/// #
/// # fn main() { test().unwrap(); }
/// # fn test() -> Option<()> {
#[doc = concat!("let ten = NonZero::new(10", stringify!($Int), ")?;")]
#[doc = concat!("let three = NonZero::new(3", stringify!($Int), ")?;")]
///
/// assert_eq!(ten.isqrt(), three);
/// # Some(())
/// # }
#[unstable(feature = "isqrt", issue = "116226")]
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
#[must_use = "this returns the result of the operation, \
without modifying the original"]
#[inline]
pub const fn isqrt(self) -> Self {
// The algorithm is based on the one presented in
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
// which cites as source the following C code:
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.

let mut op = self.get();
let mut res = 0;
let mut one = 1 << (self.ilog2() & !1);

while one != 0 {
if op >= res + one {
op -= res + one;
res = (res >> 1) + one;
} else {
res >>= 1;
}
one >>= 2;
}

// SAFETY: The result fits in an integer with half as many bits.
// Inform the optimizer about it.
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };

// SAFETY: The square root of an integer >= 1 is always >= 1.
unsafe { Self::new_unchecked(res) }
}
};

// Associated items for signed nonzero types only.
(
Self = $Ty:ident,
Primitive = signed $Int:ident,
UnsignedPrimitive = $Uint:ty,
) => {
Expand Down
Loading
Loading