Skip to content

Commit

Permalink
apply micro optimizations to DeepRejectCtxt
Browse files Browse the repository at this point in the history
  • Loading branch information
Bryanskiy committed Aug 16, 2024
1 parent 6738c1e commit b4eaef1
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 118 deletions.
10 changes: 3 additions & 7 deletions compiler/rustc_hir_typeck/src/method/suggest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_hir::{self as hir, ExprKind, HirId, Node, PathSegment, QPath};
use rustc_infer::infer::{self, RegionVariableOrigin};
use rustc_middle::bug;
use rustc_middle::ty::fast_reject::{simplify_type, DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::{new_reject_ctxt, simplify_type, DeepRejectCtxt, TreatParams};
use rustc_middle::ty::print::{
with_crate_prefix, with_forced_trimmed_paths, PrintTraitRefExt as _,
};
Expand Down Expand Up @@ -2234,12 +2234,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let target_ty = self
.autoderef(sugg_span, rcvr_ty)
.find(|(rcvr_ty, _)| {
DeepRejectCtxt::new(
self.tcx,
TreatParams::AsRigid,
TreatParams::InstantiateWithInfer,
)
.types_may_unify(*rcvr_ty, impl_ty)
new_reject_ctxt!(self.tcx, AsRigid, InstantiateWithInfer)
.types_may_unify(*rcvr_ty, impl_ty)
})
.map_or(impl_ty, |(ty, _)| ty)
.peel_refs();
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/ty/fast_reject.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use rustc_hir::def_id::DefId;
pub use rustc_type_ir::fast_reject::*;
pub use rustc_type_ir::new_reject_ctxt;

use super::TyCtxt;

pub type DeepRejectCtxt<'tcx> = rustc_type_ir::fast_reject::DeepRejectCtxt<TyCtxt<'tcx>>;
pub type DeepRejectCtxt<I, const LHS: bool, const RHS: bool> =
rustc_type_ir::fast_reject::DeepRejectCtxt<I, LHS, RHS>;

pub type SimplifiedType = rustc_type_ir::fast_reject::SimplifiedType<DefId>;
12 changes: 5 additions & 7 deletions compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod weak_types;
use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::{self as ty, Interner, NormalizesTo, Upcast as _};
use rustc_type_ir::{self as ty, new_reject_ctxt, Interner, NormalizesTo, Upcast as _};
use tracing::instrument;

use crate::delegate::SolverDelegate;
Expand Down Expand Up @@ -144,12 +144,10 @@ where

let goal_trait_ref = goal.predicate.alias.trait_ref(cx);
let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::AsRigid, TreatParams::InstantiateWithInfer)
.args_may_unify(
goal.predicate.alias.trait_ref(cx).args,
impl_trait_ref.skip_binder().args,
)
{
if !new_reject_ctxt!(ecx.cx(), AsRigid, InstantiateWithInfer).args_may_unify(
goal.predicate.alias.trait_ref(cx).args,
impl_trait_ref.skip_binder().args,
) {
return Err(NoSolution);
}

Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _;
use rustc_type_ir::{self as ty, elaborate, Interner, TraitPredicate, Upcast as _};
use rustc_type_ir::{
self as ty, elaborate, new_reject_ctxt, Interner, TraitPredicate, Upcast as _,
};
use tracing::{instrument, trace};

use crate::delegate::SolverDelegate;
Expand Down Expand Up @@ -47,7 +49,7 @@ where
let cx = ecx.cx();

let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::AsRigid, TreatParams::InstantiateWithInfer)
if !new_reject_ctxt!(ecx.cx(), AsRigid, InstantiateWithInfer)
.args_may_unify(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args)
{
return Err(NoSolution);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_middle::traits::{ObligationCause, ObligationCauseCode};
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::{new_reject_ctxt, DeepRejectCtxt, TreatParams};
use rustc_middle::ty::print::{FmtPrinter, Printer};
use rustc_middle::ty::{self, suggest_constraining_type_param, Ty};
use rustc_span::def_id::DefId;
Expand Down Expand Up @@ -316,12 +316,8 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(
tcx,
TreatParams::AsRigid,
TreatParams::InstantiateWithInfer,
)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
if new_reject_ctxt!(tcx, AsRigid, InstantiateWithInfer)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
}
Expand All @@ -341,12 +337,8 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.expected, |did| {
if DeepRejectCtxt::new(
tcx,
TreatParams::AsRigid,
TreatParams::InstantiateWithInfer,
)
.types_may_unify(values.expected, tcx.type_of(did).skip_binder())
if new_reject_ctxt!(tcx, AsRigid, InstantiateWithInfer)
.types_may_unify(values.expected, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
}
Expand All @@ -365,12 +357,8 @@ impl<T> Trait<T> for X {
{
let mut has_matching_impl = false;
tcx.for_each_relevant_impl(def_id, values.found, |did| {
if DeepRejectCtxt::new(
tcx,
TreatParams::AsRigid,
TreatParams::InstantiateWithInfer,
)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
if new_reject_ctxt!(tcx, AsRigid, InstantiateWithInfer)
.types_may_unify(values.found, tcx.type_of(did).skip_binder())
{
has_matching_impl = true;
}
Expand Down
8 changes: 2 additions & 6 deletions compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use rustc_middle::bug;
use rustc_middle::traits::query::NoSolution;
use rustc_middle::traits::solve::{CandidateSource, Certainty, Goal};
use rustc_middle::traits::specialization_graph::OverlapMode;
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::{new_reject_ctxt, DeepRejectCtxt, TreatParams};
use rustc_middle::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor};
use rustc_middle::ty::{self, Ty, TyCtxt};
pub use rustc_next_trait_solver::coherence::*;
Expand Down Expand Up @@ -94,11 +94,7 @@ pub fn overlapping_impls(
// Before doing expensive operations like entering an inference context, do
// a quick check via fast_reject to tell if the impl headers could possibly
// unify.
let drcx = DeepRejectCtxt::new(
tcx,
TreatParams::InstantiateWithInfer,
TreatParams::InstantiateWithInfer,
);
let drcx = new_reject_ctxt!(tcx, InstantiateWithInfer, InstantiateWithInfer);
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
let may_overlap = match (impl1_ref, impl2_ref) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use hir::LangItem;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_hir as hir;
use rustc_infer::traits::{Obligation, ObligationCause, PolyTraitObligation, SelectionError};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::{new_reject_ctxt, DeepRejectCtxt, TreatParams};
use rustc_middle::ty::{self, ToPolyTraitRef, Ty, TypeVisitableExt};
use rustc_middle::{bug, span_bug};

Expand Down Expand Up @@ -580,11 +580,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
return;
}

let drcx = DeepRejectCtxt::new(
self.tcx(),
TreatParams::AsRigid,
TreatParams::InstantiateWithInfer,
);
let drcx = new_reject_ctxt!(self.tcx(), AsRigid, InstantiateWithInfer);
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
self.tcx().for_each_relevant_impl(
obligation.predicate.def_id(),
Expand Down
130 changes: 68 additions & 62 deletions compiler/rustc_type_ir/src/fast_reject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl<DefId> SimplifiedType<DefId> {

/// Given generic arguments, could they be unified after
/// replacing parameters with inference variables or placeholders.
/// This behavior is toggled using the `TreatParams` fields.
/// This behavior is toggled using the const generics.
///
/// We use this to quickly reject impl/wc candidates without needing
/// to instantiate generic arguments/having to enter a probe.
Expand All @@ -182,15 +182,31 @@ impl<DefId> SimplifiedType<DefId> {
/// impls only have to overlap for some value, so we treat parameters
/// on both sides like inference variables.
#[derive(Debug, Clone, Copy)]
pub struct DeepRejectCtxt<I: Interner> {
treat_lhs_params: TreatParams,
treat_rhs_params: TreatParams,
pub struct DeepRejectCtxt<I: Interner, const TREAT_LHS_PARAMS: bool, const TREAT_RHS_PARAMS: bool> {
_interner: PhantomData<I>,
}

impl<I: Interner> DeepRejectCtxt<I> {
pub fn new(_interner: I, treat_lhs_params: TreatParams, treat_rhs_params: TreatParams) -> Self {
DeepRejectCtxt { treat_lhs_params, treat_rhs_params, _interner: PhantomData }
impl TreatParams {
pub const fn into_bool(&self) -> bool {
match *self {
TreatParams::InstantiateWithInfer => true,
TreatParams::AsRigid => false,
}
}
}

#[macro_export]
macro_rules! new_reject_ctxt {
($interner:expr, $lhs:ident, $rhs:ident) => {
DeepRejectCtxt::<_, {TreatParams::$lhs.into_bool()}, {TreatParams::$rhs.into_bool()}>::new($interner)
}
}

impl<I: Interner, const TREAT_LHS_PARAMS: bool, const TREAT_RHS_PARAMS: bool>
DeepRejectCtxt<I, TREAT_LHS_PARAMS, TREAT_RHS_PARAMS>
{
pub fn new(_interner: I) -> Self {
DeepRejectCtxt { _interner: PhantomData }
}

pub fn args_may_unify(
Expand All @@ -215,45 +231,6 @@ impl<I: Interner> DeepRejectCtxt<I> {

pub fn types_may_unify(self, lhs: I::Ty, rhs: I::Ty) -> bool {
match (lhs.kind(), rhs.kind()) {
(ty::Error(_), _) | (_, ty::Error(_)) => true,

// As we're walking the whole type, it may encounter projections
// inside of binders and what not, so we're just going to assume that
// projections can unify with other stuff.
//
// Looking forward to lazy normalization this is the safer strategy anyways.
(ty::Alias(..), _) | (_, ty::Alias(..)) => true,

// Bound type variables may unify with rigid types e.g. when using
// non-lifetime binders.
(ty::Bound(..), _) | (_, ty::Bound(..)) => true,

(ty::Infer(var), _) => self.var_and_ty_may_unify(var, rhs),
(_, ty::Infer(var)) => self.var_and_ty_may_unify(var, lhs),

(ty::Param(lhs), ty::Param(rhs)) => {
match (self.treat_lhs_params, self.treat_rhs_params) {
(TreatParams::AsRigid, TreatParams::AsRigid) => lhs == rhs,
(TreatParams::InstantiateWithInfer, _)
| (_, TreatParams::InstantiateWithInfer) => true,
}
}
(ty::Param(_), _) => self.treat_lhs_params == TreatParams::InstantiateWithInfer,
(_, ty::Param(_)) => self.treat_rhs_params == TreatParams::InstantiateWithInfer,

// Placeholder types don't unify with anything on their own.
(ty::Placeholder(lhs), ty::Placeholder(rhs)) => lhs == rhs,

// Purely rigid types, use structural equivalence.
(ty::Bool, ty::Bool)
| (ty::Char, ty::Char)
| (ty::Int(_), ty::Int(_))
| (ty::Uint(_), ty::Uint(_))
| (ty::Float(_), ty::Float(_))
| (ty::Str, ty::Str)
| (ty::Never, ty::Never)
| (ty::Foreign(_), ty::Foreign(_)) => lhs == rhs,

(ty::Ref(_, lhs_ty, lhs_mutbl), ty::Ref(_, rhs_ty, rhs_mutbl)) => {
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
}
Expand All @@ -262,37 +239,63 @@ impl<I: Interner> DeepRejectCtxt<I> {
lhs_def == rhs_def && self.args_may_unify(lhs_args, rhs_args)
}

(ty::Pat(lhs_ty, _), ty::Pat(rhs_ty, _)) => {
// FIXME(pattern_types): take pattern into account
self.types_may_unify(lhs_ty, rhs_ty)
}
(ty::Infer(var), _) => self.var_and_ty_may_unify(var, rhs),
(_, ty::Infer(var)) => self.var_and_ty_may_unify(var, lhs),

(ty::Slice(lhs_ty), ty::Slice(rhs_ty)) => self.types_may_unify(lhs_ty, rhs_ty),
(ty::Int(_), ty::Int(_)) | (ty::Uint(_), ty::Uint(_)) => lhs == rhs,

(ty::Array(lhs_ty, lhs_len), ty::Array(rhs_ty, rhs_len)) => {
self.types_may_unify(lhs_ty, rhs_ty) && self.consts_may_unify(lhs_len, rhs_len)
}
(ty::Param(lhs), ty::Param(rhs)) => match (TREAT_LHS_PARAMS, TREAT_RHS_PARAMS) {
(false, false) => lhs == rhs,
(true, _) | (_, true) => true,
},

// As we're walking the whole type, it may encounter projections
// inside of binders and what not, so we're just going to assume that
// projections can unify with other stuff.
//
// Looking forward to lazy normalization this is the safer strategy anyways.
(ty::Alias(..), _) | (_, ty::Alias(..)) => true,

(ty::Bound(..), _) | (_, ty::Bound(..)) => true,

(ty::Param(_), _) => TREAT_LHS_PARAMS,
(_, ty::Param(_)) => TREAT_RHS_PARAMS,

(ty::Tuple(lhs), ty::Tuple(rhs)) => {
lhs.len() == rhs.len()
&& iter::zip(lhs.iter(), rhs.iter())
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
}

(ty::Array(lhs_ty, lhs_len), ty::Array(rhs_ty, rhs_len)) => {
self.types_may_unify(lhs_ty, rhs_ty) && self.consts_may_unify(lhs_len, rhs_len)
}

(ty::RawPtr(lhs_ty, lhs_mutbl), ty::RawPtr(rhs_ty, rhs_mutbl)) => {
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
}

(ty::Slice(lhs_ty), ty::Slice(rhs_ty)) => self.types_may_unify(lhs_ty, rhs_ty),

(ty::Float(_), ty::Float(_))
| (ty::Str, ty::Str)
| (ty::Bool, ty::Bool)
| (ty::Char, ty::Char)
| (ty::Never, ty::Never)
| (ty::Foreign(_), ty::Foreign(_)) => lhs == rhs,

(ty::Dynamic(lhs_preds, ..), ty::Dynamic(rhs_preds, ..)) => {
// Ideally we would walk the existential predicates here or at least
// compare their length. But considering that the relevant `Relate` impl
// actually sorts and deduplicates these, that doesn't work.
lhs_preds.principal_def_id() == rhs_preds.principal_def_id()
}

// Placeholder types don't unify with anything on their own.
(ty::Placeholder(lhs), ty::Placeholder(rhs)) => lhs == rhs,

(ty::FnPtr(lhs_sig_tys, lhs_hdr), ty::FnPtr(rhs_sig_tys, rhs_hdr)) => {
let lhs_sig_tys = lhs_sig_tys.skip_binder().inputs_and_output;

let rhs_sig_tys = rhs_sig_tys.skip_binder().inputs_and_output;

lhs_hdr == rhs_hdr
Expand All @@ -313,7 +316,14 @@ impl<I: Interner> DeepRejectCtxt<I> {
ty::CoroutineWitness(rhs_def_id, rhs_args),
) => lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args),

(ty::Placeholder(_), _)
(ty::Pat(lhs_ty, _), ty::Pat(rhs_ty, _)) => {
// FIXME(pattern_types): take pattern into account
self.types_may_unify(lhs_ty, rhs_ty)
}

(ty::Error(..), _)
| (_, ty::Error(..))
| (ty::Placeholder(_), _)
| (_, ty::Placeholder(_))
| (ty::Bool, _)
| (_, ty::Bool)
Expand Down Expand Up @@ -371,12 +381,8 @@ impl<I: Interner> DeepRejectCtxt<I> {
(ty::ConstKind::Value(..), ty::ConstKind::Placeholder(_))
| (ty::ConstKind::Placeholder(_), ty::ConstKind::Value(..)) => false,

(ty::ConstKind::Param(_), ty::ConstKind::Value(..)) => {
self.treat_lhs_params == TreatParams::InstantiateWithInfer
}
(ty::ConstKind::Value(..), ty::ConstKind::Param(_)) => {
self.treat_rhs_params == TreatParams::InstantiateWithInfer
}
(ty::ConstKind::Param(_), ty::ConstKind::Value(..)) => TREAT_LHS_PARAMS,
(ty::ConstKind::Value(..), ty::ConstKind::Param(_)) => TREAT_RHS_PARAMS,

_ => true,
}
Expand Down
9 changes: 3 additions & 6 deletions src/librustdoc/html/render/write_shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use indexmap::IndexMap;
use itertools::Itertools;
use rustc_data_structures::flock;
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_middle::ty::fast_reject::{new_reject_ctxt, DeepRejectCtxt};
use rustc_span::def_id::DefId;
use rustc_span::Symbol;
use serde::ser::SerializeSeq;
Expand Down Expand Up @@ -507,11 +507,8 @@ else if (window.initSearch) window.initSearch(searchIndex);
// Be aware of `tests/rustdoc/type-alias/deeply-nested-112515.rs` which might regress.
let Some(impl_did) = impl_item_id.as_def_id() else { continue };
let for_ty = self.cx.tcx().type_of(impl_did).skip_binder();
let reject_cx = DeepRejectCtxt::new(
self.cx.tcx(),
TreatParams::InstantiateWithInfer,
TreatParams::InstantiateWithInfer,
);
let reject_cx =
new_reject_ctxt!(self.cx.tcx(), InstantiateWithInfer, InstantiateWithInfer);
if !reject_cx.types_may_unify(aliased_ty, for_ty) {
continue;
}
Expand Down

0 comments on commit b4eaef1

Please sign in to comment.