diff --git a/compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs b/compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs index bd8b43e28e540..9fce927e2c7f2 100644 --- a/compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs +++ b/compiler/rustc_hir_analysis/src/coherence/inherent_impls.rs @@ -93,7 +93,8 @@ impl<'tcx> InherentCollect<'tcx> { } } - if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) { + if let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer) + { self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id); } else { bug!("unexpected self type: {:?}", self_ty); @@ -132,7 +133,7 @@ impl<'tcx> InherentCollect<'tcx> { } } - if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::AsCandidateKey) { + if let Some(simp) = simplify_type(self.tcx, ty, TreatParams::InstantiateWithInfer) { self.impls_map.incoherent_impls.entry(simp).or_default().push(impl_def_id); } else { bug!("unexpected primitive type: {:?}", ty); diff --git a/compiler/rustc_hir_typeck/src/method/probe.rs b/compiler/rustc_hir_typeck/src/method/probe.rs index 0cf5403b3c085..e9d756b4c579c 100644 --- a/compiler/rustc_hir_typeck/src/method/probe.rs +++ b/compiler/rustc_hir_typeck/src/method/probe.rs @@ -715,7 +715,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> { } fn assemble_inherent_candidates_for_incoherent_ty(&mut self, self_ty: Ty<'tcx>) { - let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::AsCandidateKey) else { + let Some(simp) = simplify_type(self.tcx, self_ty, TreatParams::InstantiateWithInfer) else { bug!("unexpected incoherent type: {:?}", self_ty) }; for &impl_def_id in self.tcx.incoherent_impls(simp).into_iter().flatten() { diff --git a/compiler/rustc_hir_typeck/src/method/suggest.rs b/compiler/rustc_hir_typeck/src/method/suggest.rs index 9ea57e4aa6168..14ad5830111b4 100644 --- a/compiler/rustc_hir_typeck/src/method/suggest.rs +++ b/compiler/rustc_hir_typeck/src/method/suggest.rs @@ -2235,8 +2235,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let target_ty = self .autoderef(sugg_span, rcvr_ty) .find(|(rcvr_ty, _)| { - DeepRejectCtxt::new(self.tcx, TreatParams::ForLookup) - .types_may_unify(*rcvr_ty, impl_ty) + DeepRejectCtxt::relate_rigid_infer(self.tcx).types_may_unify(*rcvr_ty, impl_ty) }) .map_or(impl_ty, |(ty, _)| ty) .peel_refs(); @@ -2498,7 +2497,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { .into_iter() .any(|info| self.associated_value(info.def_id, item_name).is_some()); let found_assoc = |ty: Ty<'tcx>| { - simplify_type(tcx, ty, TreatParams::AsCandidateKey) + simplify_type(tcx, ty, TreatParams::InstantiateWithInfer) .and_then(|simp| { tcx.incoherent_impls(simp) .into_iter() @@ -3963,7 +3962,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // cases where a positive bound implies a negative impl. (candidates, Vec::new()) } else if let Some(simp_rcvr_ty) = - simplify_type(self.tcx, rcvr_ty, TreatParams::ForLookup) + simplify_type(self.tcx, rcvr_ty, TreatParams::AsRigid) { let mut potential_candidates = Vec::new(); let mut explicitly_negative = Vec::new(); @@ -3981,7 +3980,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { .any(|header| { let imp = header.trait_ref.instantiate_identity(); let imp_simp = - simplify_type(self.tcx, imp.self_ty(), TreatParams::ForLookup); + simplify_type(self.tcx, imp.self_ty(), TreatParams::AsRigid); imp_simp.is_some_and(|s| s == simp_rcvr_ty) }) { diff --git a/compiler/rustc_metadata/src/rmeta/encoder.rs b/compiler/rustc_metadata/src/rmeta/encoder.rs index 88256c4db04ec..15234ff4b15d1 100644 --- a/compiler/rustc_metadata/src/rmeta/encoder.rs +++ b/compiler/rustc_metadata/src/rmeta/encoder.rs @@ -2033,7 +2033,7 @@ impl<'a, 'tcx> EncodeContext<'a, 'tcx> { let simplified_self_ty = fast_reject::simplify_type( self.tcx, trait_ref.self_ty(), - TreatParams::AsCandidateKey, + TreatParams::InstantiateWithInfer, ); trait_impls .entry(trait_ref.def_id) diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 20828067c469c..d057a7017f3d6 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -426,7 +426,7 @@ impl<'tcx> Interner for TyCtxt<'tcx> { let simp = ty::fast_reject::simplify_type( tcx, self_ty, - ty::fast_reject::TreatParams::ForLookup, + ty::fast_reject::TreatParams::AsRigid, ) .unwrap(); consider_impls_for_simplified_type(simp); diff --git a/compiler/rustc_middle/src/ty/fast_reject.rs b/compiler/rustc_middle/src/ty/fast_reject.rs index 91344c4e39c8b..2945a0be424f8 100644 --- a/compiler/rustc_middle/src/ty/fast_reject.rs +++ b/compiler/rustc_middle/src/ty/fast_reject.rs @@ -3,6 +3,14 @@ pub use rustc_type_ir::fast_reject::*; use super::TyCtxt; -pub type DeepRejectCtxt<'tcx> = rustc_type_ir::fast_reject::DeepRejectCtxt>; +pub type DeepRejectCtxt< + 'tcx, + const INSTANTIATE_LHS_WITH_INFER: bool, + const INSTANTIATE_RHS_WITH_INFER: bool, +> = rustc_type_ir::fast_reject::DeepRejectCtxt< + TyCtxt<'tcx>, + INSTANTIATE_LHS_WITH_INFER, + INSTANTIATE_RHS_WITH_INFER, +>; pub type SimplifiedType = rustc_type_ir::fast_reject::SimplifiedType; diff --git a/compiler/rustc_middle/src/ty/trait_def.rs b/compiler/rustc_middle/src/ty/trait_def.rs index dfb137f738f1e..82690f70e5f18 100644 --- a/compiler/rustc_middle/src/ty/trait_def.rs +++ b/compiler/rustc_middle/src/ty/trait_def.rs @@ -168,9 +168,9 @@ impl<'tcx> TyCtxt<'tcx> { // whose outer level is not a parameter or projection. Especially for things like // `T: Clone` this is incredibly useful as we would otherwise look at all the impls // of `Clone` for `Option`, `Vec`, `ConcreteType` and so on. - // Note that we're using `TreatParams::ForLookup` to query `non_blanket_impls` while using - // `TreatParams::AsCandidateKey` while actually adding them. - if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::ForLookup) { + // Note that we're using `TreatParams::AsRigid` to query `non_blanket_impls` while using + // `TreatParams::InstantiateWithInfer` while actually adding them. + if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsRigid) { if let Some(impls) = impls.non_blanket_impls.get(&simp) { for &impl_def_id in impls { f(impl_def_id); @@ -190,7 +190,9 @@ impl<'tcx> TyCtxt<'tcx> { self_ty: Ty<'tcx>, ) -> impl Iterator + 'tcx { let impls = self.trait_impls_of(trait_def_id); - if let Some(simp) = fast_reject::simplify_type(self, self_ty, TreatParams::AsCandidateKey) { + if let Some(simp) = + fast_reject::simplify_type(self, self_ty, TreatParams::InstantiateWithInfer) + { if let Some(impls) = impls.non_blanket_impls.get(&simp) { return impls.iter().copied(); } @@ -239,7 +241,7 @@ pub(super) fn trait_impls_of_provider(tcx: TyCtxt<'_>, trait_id: DefId) -> Trait let impl_self_ty = tcx.type_of(impl_def_id).instantiate_identity(); if let Some(simplified_self_ty) = - fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::AsCandidateKey) + fast_reject::simplify_type(tcx, impl_self_ty, TreatParams::InstantiateWithInfer) { impls.non_blanket_impls.entry(simplified_self_ty).or_default().push(impl_def_id); } else { diff --git a/compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs b/compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs index c7a9846cd97f0..17b6ec7e2bb2c 100644 --- a/compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/normalizes_to/mod.rs @@ -3,7 +3,7 @@ mod inherent; mod opaque_types; mod weak_types; -use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_type_ir::fast_reject::DeepRejectCtxt; use rustc_type_ir::inherent::*; use rustc_type_ir::lang_items::TraitSolverLangItem; use rustc_type_ir::{self as ty, Interner, NormalizesTo, Upcast as _}; @@ -106,6 +106,12 @@ where if let Some(projection_pred) = assumption.as_projection_clause() { if projection_pred.projection_def_id() == goal.predicate.def_id() { let cx = ecx.cx(); + if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify( + goal.predicate.alias.args, + projection_pred.skip_binder().projection_term.args, + ) { + return Err(NoSolution); + } ecx.probe_trait_candidate(source).enter(|ecx| { let assumption_projection_pred = ecx.instantiate_binder_with_infer(projection_pred); @@ -144,7 +150,7 @@ where let goal_trait_ref = goal.predicate.alias.trait_ref(cx); let impl_trait_ref = cx.impl_trait_ref(impl_def_id); - if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup).args_may_unify( + if !DeepRejectCtxt::relate_rigid_infer(ecx.cx()).args_may_unify( goal.predicate.alias.trait_ref(cx).args, impl_trait_ref.skip_binder().args, ) { diff --git a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs index 67b001d0cceda..683d8dab3b2a8 100644 --- a/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs +++ b/compiler/rustc_next_trait_solver/src/solve/trait_goals.rs @@ -2,7 +2,7 @@ use rustc_ast_ir::Movability; use rustc_type_ir::data_structures::IndexSet; -use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_type_ir::fast_reject::DeepRejectCtxt; use rustc_type_ir::inherent::*; use rustc_type_ir::lang_items::TraitSolverLangItem; use rustc_type_ir::visit::TypeVisitableExt as _; @@ -47,7 +47,7 @@ where let cx = ecx.cx(); let impl_trait_ref = cx.impl_trait_ref(impl_def_id); - if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup) + if !DeepRejectCtxt::relate_rigid_infer(ecx.cx()) .args_may_unify(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args) { return Err(NoSolution); @@ -124,6 +124,13 @@ where if trait_clause.def_id() == goal.predicate.def_id() && trait_clause.polarity() == goal.predicate.polarity { + if !DeepRejectCtxt::relate_rigid_rigid(ecx.cx()).args_may_unify( + goal.predicate.trait_ref.args, + trait_clause.skip_binder().trait_ref.args, + ) { + return Err(NoSolution); + } + ecx.probe_trait_candidate(source).enter(|ecx| { let assumption_trait_pred = ecx.instantiate_binder_with_infer(trait_clause); ecx.eq( diff --git a/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs b/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs index db71331d07f07..7bfc6471dc83f 100644 --- a/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs +++ b/compiler/rustc_trait_selection/src/error_reporting/infer/note_and_explain.rs @@ -4,7 +4,7 @@ use rustc_hir as hir; use rustc_hir::def::DefKind; use rustc_middle::traits::{ObligationCause, ObligationCauseCode}; use rustc_middle::ty::error::{ExpectedFound, TypeError}; -use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_middle::ty::fast_reject::DeepRejectCtxt; use rustc_middle::ty::print::{FmtPrinter, Printer}; use rustc_middle::ty::{self, suggest_constraining_type_param, Ty}; use rustc_span::def_id::DefId; @@ -317,7 +317,7 @@ impl Trait for X { { let mut has_matching_impl = false; tcx.for_each_relevant_impl(def_id, values.found, |did| { - if DeepRejectCtxt::new(tcx, TreatParams::ForLookup) + if DeepRejectCtxt::relate_rigid_infer(tcx) .types_may_unify(values.found, tcx.type_of(did).skip_binder()) { has_matching_impl = true; @@ -338,7 +338,7 @@ impl Trait for X { { let mut has_matching_impl = false; tcx.for_each_relevant_impl(def_id, values.expected, |did| { - if DeepRejectCtxt::new(tcx, TreatParams::ForLookup) + if DeepRejectCtxt::relate_rigid_infer(tcx) .types_may_unify(values.expected, tcx.type_of(did).skip_binder()) { has_matching_impl = true; @@ -358,7 +358,7 @@ impl Trait for X { { let mut has_matching_impl = false; tcx.for_each_relevant_impl(def_id, values.found, |did| { - if DeepRejectCtxt::new(tcx, TreatParams::ForLookup) + if DeepRejectCtxt::relate_rigid_infer(tcx) .types_may_unify(values.found, tcx.type_of(did).skip_binder()) { has_matching_impl = true; diff --git a/compiler/rustc_trait_selection/src/traits/coherence.rs b/compiler/rustc_trait_selection/src/traits/coherence.rs index 8558520897b5c..7fbb9846943bb 100644 --- a/compiler/rustc_trait_selection/src/traits/coherence.rs +++ b/compiler/rustc_trait_selection/src/traits/coherence.rs @@ -15,7 +15,7 @@ use rustc_middle::bug; use rustc_middle::traits::query::NoSolution; use rustc_middle::traits::solve::{CandidateSource, Certainty, Goal}; use rustc_middle::traits::specialization_graph::OverlapMode; -use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_middle::ty::fast_reject::DeepRejectCtxt; use rustc_middle::ty::visit::{TypeSuperVisitable, TypeVisitable, TypeVisitableExt, TypeVisitor}; use rustc_middle::ty::{self, Ty, TyCtxt}; pub use rustc_next_trait_solver::coherence::*; @@ -95,7 +95,7 @@ pub fn overlapping_impls( // Before doing expensive operations like entering an inference context, do // a quick check via fast_reject to tell if the impl headers could possibly // unify. - let drcx = DeepRejectCtxt::new(tcx, TreatParams::AsCandidateKey); + let drcx = DeepRejectCtxt::relate_infer_infer(tcx); let impl1_ref = tcx.impl_trait_ref(impl1_def_id); let impl2_ref = tcx.impl_trait_ref(impl2_def_id); let may_overlap = match (impl1_ref, impl2_ref) { diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index 4702fd866c1b6..630acc91fbedc 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -13,6 +13,7 @@ use rustc_infer::traits::ObligationCauseCode; use rustc_middle::traits::select::OverflowError; pub use rustc_middle::traits::Reveal; use rustc_middle::traits::{BuiltinImplSource, ImplSource, ImplSourceUserDefinedData}; +use rustc_middle::ty::fast_reject::DeepRejectCtxt; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::visit::{MaxUniverse, TypeVisitable, TypeVisitableExt}; use rustc_middle::ty::{self, Term, Ty, TyCtxt, Upcast}; @@ -886,6 +887,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( potentially_unnormalized_candidates: bool, ) { let infcx = selcx.infcx; + let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx()); for predicate in env_predicates { let bound_predicate = predicate.kind(); if let ty::ClauseKind::Projection(data) = predicate.kind().skip_binder() { @@ -894,6 +896,12 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>( continue; } + if !drcx + .args_may_unify(obligation.predicate.args, data.skip_binder().projection_term.args) + { + continue; + } + let is_match = infcx.probe(|_| { selcx.match_projection_projections( obligation, diff --git a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs index 96faa5236b198..77efc2fc2dbfd 100644 --- a/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs +++ b/compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs @@ -13,7 +13,7 @@ use hir::LangItem; use rustc_data_structures::fx::{FxHashSet, FxIndexSet}; use rustc_hir as hir; use rustc_infer::traits::{Obligation, ObligationCause, PolyTraitObligation, SelectionError}; -use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_middle::ty::fast_reject::DeepRejectCtxt; use rustc_middle::ty::{self, ToPolyTraitRef, Ty, TypeVisitableExt}; use rustc_middle::{bug, span_bug}; use tracing::{debug, instrument, trace}; @@ -248,11 +248,17 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { .filter(|p| p.def_id() == stack.obligation.predicate.def_id()) .filter(|p| p.polarity() == stack.obligation.predicate.polarity()); + let drcx = DeepRejectCtxt::relate_rigid_rigid(self.tcx()); + let obligation_args = stack.obligation.predicate.skip_binder().trait_ref.args; // Keep only those bounds which may apply, and propagate overflow if it occurs. for bound in bounds { + let bound_trait_ref = bound.map_bound(|t| t.trait_ref); + if !drcx.args_may_unify(obligation_args, bound_trait_ref.skip_binder().args) { + continue; + } // FIXME(oli-obk): it is suspicious that we are dropping the constness and // polarity here. - let wc = self.where_clause_may_apply(stack, bound.map_bound(|t| t.trait_ref))?; + let wc = self.where_clause_may_apply(stack, bound_trait_ref)?; if wc.may_apply() { candidates.vec.push(ParamCandidate(bound)); } @@ -581,7 +587,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { return; } - let drcx = DeepRejectCtxt::new(self.tcx(), TreatParams::ForLookup); + let drcx = DeepRejectCtxt::relate_rigid_infer(self.tcx()); let obligation_args = obligation.predicate.skip_binder().trait_ref.args; self.tcx().for_each_relevant_impl( obligation.predicate.def_id(), diff --git a/compiler/rustc_trait_selection/src/traits/specialize/specialization_graph.rs b/compiler/rustc_trait_selection/src/traits/specialize/specialization_graph.rs index 0fdaf40b136d8..13620f4b8d9cc 100644 --- a/compiler/rustc_trait_selection/src/traits/specialize/specialization_graph.rs +++ b/compiler/rustc_trait_selection/src/traits/specialize/specialization_graph.rs @@ -41,7 +41,7 @@ impl<'tcx> Children { fn insert_blindly(&mut self, tcx: TyCtxt<'tcx>, impl_def_id: DefId) { let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder(); if let Some(st) = - fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey) + fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer) { debug!("insert_blindly: impl_def_id={:?} st={:?}", impl_def_id, st); self.non_blanket_impls.entry(st).or_default().push(impl_def_id) @@ -58,7 +58,7 @@ impl<'tcx> Children { let trait_ref = tcx.impl_trait_ref(impl_def_id).unwrap().skip_binder(); let vec: &mut Vec; if let Some(st) = - fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey) + fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer) { debug!("remove_existing: impl_def_id={:?} st={:?}", impl_def_id, st); vec = self.non_blanket_impls.get_mut(&st).unwrap(); @@ -279,7 +279,7 @@ impl<'tcx> Graph { let mut parent = trait_def_id; let mut last_lint = None; let simplified = - fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::AsCandidateKey); + fast_reject::simplify_type(tcx, trait_ref.self_ty(), TreatParams::InstantiateWithInfer); // Descend the specialization tree, where `parent` is the current parent node. loop { diff --git a/compiler/rustc_type_ir/src/fast_reject.rs b/compiler/rustc_type_ir/src/fast_reject.rs index fab4a0991175d..2c8e47bcbca2a 100644 --- a/compiler/rustc_type_ir/src/fast_reject.rs +++ b/compiler/rustc_type_ir/src/fast_reject.rs @@ -74,13 +74,13 @@ impl> ToStableHashKey for SimplifiedType pub enum TreatParams { /// Treat parameters as infer vars. This is the correct mode for caching /// an impl's type for lookup. - AsCandidateKey, + InstantiateWithInfer, /// Treat parameters as placeholders in the given environment. This is the /// correct mode for *lookup*, as during candidate selection. /// /// This also treats projections with inference variables as infer vars /// since they could be further normalized. - ForLookup, + AsRigid, } /// Tries to simplify a type by only returning the outermost injective¹ layer, if one exists. @@ -140,18 +140,16 @@ pub fn simplify_type( } ty::Placeholder(..) => Some(SimplifiedType::Placeholder), ty::Param(_) => match treat_params { - TreatParams::ForLookup => Some(SimplifiedType::Placeholder), - TreatParams::AsCandidateKey => None, + TreatParams::AsRigid => Some(SimplifiedType::Placeholder), + TreatParams::InstantiateWithInfer => None, }, ty::Alias(..) => match treat_params { // When treating `ty::Param` as a placeholder, projections also // don't unify with anything else as long as they are fully normalized. // FIXME(-Znext-solver): Can remove this `if` and always simplify to `Placeholder` // when the new solver is enabled by default. - TreatParams::ForLookup if !ty.has_non_region_infer() => { - Some(SimplifiedType::Placeholder) - } - TreatParams::ForLookup | TreatParams::AsCandidateKey => None, + TreatParams::AsRigid if !ty.has_non_region_infer() => Some(SimplifiedType::Placeholder), + TreatParams::AsRigid | TreatParams::InstantiateWithInfer => None, }, ty::Foreign(def_id) => Some(SimplifiedType::Foreign(def_id)), ty::Error(_) => Some(SimplifiedType::Error), @@ -173,29 +171,49 @@ impl SimplifiedType { } } -/// Given generic arguments from an obligation and an impl, -/// could these two be unified after replacing parameters in the -/// the impl with inference variables. +/// Given generic arguments, could they be unified after +/// replacing parameters with inference variables or placeholders. +/// This behavior is toggled using the const generics. /// -/// For obligations, parameters won't be replaced by inference -/// variables and only unify with themselves. We treat them -/// the same way we treat placeholders. +/// We use this to quickly reject impl/wc candidates without needing +/// to instantiate generic arguments/having to enter a probe. /// /// We also use this function during coherence. For coherence the /// impls only have to overlap for some value, so we treat parameters -/// on both sides like inference variables. This behavior is toggled -/// using the `treat_obligation_params` field. +/// on both sides like inference variables. #[derive(Debug, Clone, Copy)] -pub struct DeepRejectCtxt { - treat_obligation_params: TreatParams, +pub struct DeepRejectCtxt< + I: Interner, + const INSTANTIATE_LHS_WITH_INFER: bool, + const INSTANTIATE_RHS_WITH_INFER: bool, +> { _interner: PhantomData, } -impl DeepRejectCtxt { - pub fn new(_interner: I, treat_obligation_params: TreatParams) -> Self { - DeepRejectCtxt { treat_obligation_params, _interner: PhantomData } +impl DeepRejectCtxt { + /// Treat parameters in both the lhs and the rhs as rigid. + pub fn relate_rigid_rigid(_interner: I) -> DeepRejectCtxt { + DeepRejectCtxt { _interner: PhantomData } + } +} + +impl DeepRejectCtxt { + /// Treat parameters in both the lhs and the rhs as infer vars. + pub fn relate_infer_infer(_interner: I) -> DeepRejectCtxt { + DeepRejectCtxt { _interner: PhantomData } } +} +impl DeepRejectCtxt { + /// Treat parameters in the lhs as rigid, and in rhs as infer vars. + pub fn relate_rigid_infer(_interner: I) -> DeepRejectCtxt { + DeepRejectCtxt { _interner: PhantomData } + } +} + +impl + DeepRejectCtxt +{ pub fn args_may_unify( self, obligation_args: I::GenericArgs, @@ -216,11 +234,18 @@ impl DeepRejectCtxt { }) } - pub fn types_may_unify(self, obligation_ty: I::Ty, impl_ty: I::Ty) -> bool { - match impl_ty.kind() { - // Start by checking whether the type in the impl may unify with + pub fn types_may_unify(self, lhs: I::Ty, rhs: I::Ty) -> bool { + match rhs.kind() { + // Start by checking whether the `rhs` type may unify with // pretty much everything. Just return `true` in that case. - ty::Param(_) | ty::Error(_) | ty::Alias(..) => return true, + ty::Param(_) => { + if INSTANTIATE_RHS_WITH_INFER { + return true; + } + } + ty::Error(_) | ty::Alias(..) | ty::Bound(..) => return true, + ty::Infer(var) => return self.var_and_ty_may_unify(var, lhs), + // These types only unify with inference variables or their own // variant. ty::Bool @@ -238,159 +263,217 @@ impl DeepRejectCtxt { | ty::Ref(..) | ty::Never | ty::Tuple(..) + | ty::FnDef(..) | ty::FnPtr(..) - | ty::Foreign(..) => debug_assert!(impl_ty.is_known_rigid()), - ty::FnDef(..) | ty::Closure(..) | ty::CoroutineClosure(..) | ty::Coroutine(..) | ty::CoroutineWitness(..) - | ty::Placeholder(..) - | ty::Bound(..) - | ty::Infer(_) => panic!("unexpected impl_ty: {impl_ty:?}"), - } + | ty::Foreign(_) + | ty::Placeholder(_) => {} + }; - let k = impl_ty.kind(); - match obligation_ty.kind() { - // Purely rigid types, use structural equivalence. - ty::Bool - | ty::Char - | ty::Int(_) - | ty::Uint(_) - | ty::Float(_) - | ty::Str - | ty::Never - | ty::Foreign(_) => obligation_ty == impl_ty, - ty::Ref(_, obl_ty, obl_mutbl) => match k { - ty::Ref(_, impl_ty, impl_mutbl) => { - obl_mutbl == impl_mutbl && self.types_may_unify(obl_ty, impl_ty) + // For purely rigid types, use structural equivalence. + match lhs.kind() { + ty::Ref(_, lhs_ty, lhs_mutbl) => match rhs.kind() { + ty::Ref(_, rhs_ty, rhs_mutbl) => { + lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty) } _ => false, }, - ty::Adt(obl_def, obl_args) => match k { - ty::Adt(impl_def, impl_args) => { - obl_def == impl_def && self.args_may_unify(obl_args, impl_args) + + ty::Adt(lhs_def, lhs_args) => match rhs.kind() { + ty::Adt(rhs_def, rhs_args) => { + lhs_def == rhs_def && self.args_may_unify(lhs_args, rhs_args) } _ => false, }, - ty::Pat(obl_ty, _) => { - // FIXME(pattern_types): take pattern into account - matches!(k, ty::Pat(impl_ty, _) if self.types_may_unify(obl_ty, impl_ty)) + + // Depending on the value of const generics, we either treat generic parameters + // like placeholders or like inference variables. + ty::Param(lhs) => { + INSTANTIATE_LHS_WITH_INFER + || match rhs.kind() { + ty::Param(rhs) => lhs == rhs, + _ => false, + } } - ty::Slice(obl_ty) => { - matches!(k, ty::Slice(impl_ty) if self.types_may_unify(obl_ty, impl_ty)) + + // Placeholder types don't unify with anything on their own. + ty::Placeholder(lhs) => { + matches!(rhs.kind(), ty::Placeholder(rhs) if lhs == rhs) } - ty::Array(obl_ty, obl_len) => match k { - ty::Array(impl_ty, impl_len) => { - self.types_may_unify(obl_ty, impl_ty) - && self.consts_may_unify(obl_len, impl_len) + + ty::Infer(var) => self.var_and_ty_may_unify(var, rhs), + + // As we're walking the whole type, it may encounter projections + // inside of binders and what not, so we're just going to assume that + // projections can unify with other stuff. + // + // Looking forward to lazy normalization this is the safer strategy anyways. + ty::Alias(..) => true, + + ty::Int(_) + | ty::Uint(_) + | ty::Float(_) + | ty::Str + | ty::Bool + | ty::Char + | ty::Never + | ty::Foreign(_) => lhs == rhs, + + ty::Tuple(lhs) => match rhs.kind() { + ty::Tuple(rhs) => { + lhs.len() == rhs.len() + && iter::zip(lhs.iter(), rhs.iter()) + .all(|(lhs, rhs)| self.types_may_unify(lhs, rhs)) } _ => false, }, - ty::Tuple(obl) => match k { - ty::Tuple(imp) => { - obl.len() == imp.len() - && iter::zip(obl.iter(), imp.iter()) - .all(|(obl, imp)| self.types_may_unify(obl, imp)) + + ty::Array(lhs_ty, lhs_len) => match rhs.kind() { + ty::Array(rhs_ty, rhs_len) => { + self.types_may_unify(lhs_ty, rhs_ty) && self.consts_may_unify(lhs_len, rhs_len) } _ => false, }, - ty::RawPtr(obl_ty, obl_mutbl) => match k { - ty::RawPtr(imp_ty, imp_mutbl) => { - obl_mutbl == imp_mutbl && self.types_may_unify(obl_ty, imp_ty) + + ty::RawPtr(lhs_ty, lhs_mutbl) => match rhs.kind() { + ty::RawPtr(rhs_ty, rhs_mutbl) => { + lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty) } _ => false, }, - ty::Dynamic(obl_preds, ..) => { + + ty::Slice(lhs_ty) => { + matches!(rhs.kind(), ty::Slice(rhs_ty) if self.types_may_unify(lhs_ty, rhs_ty)) + } + + ty::Dynamic(lhs_preds, ..) => { // Ideally we would walk the existential predicates here or at least // compare their length. But considering that the relevant `Relate` impl // actually sorts and deduplicates these, that doesn't work. - matches!(k, ty::Dynamic(impl_preds, ..) if - obl_preds.principal_def_id() == impl_preds.principal_def_id() + matches!(rhs.kind(), ty::Dynamic(rhs_preds, ..) if + lhs_preds.principal_def_id() == rhs_preds.principal_def_id() ) } - ty::FnPtr(obl_sig_tys, obl_hdr) => match k { - ty::FnPtr(impl_sig_tys, impl_hdr) => { - let obl_sig_tys = obl_sig_tys.skip_binder().inputs_and_output; - let impl_sig_tys = impl_sig_tys.skip_binder().inputs_and_output; - - obl_hdr == impl_hdr - && obl_sig_tys.len() == impl_sig_tys.len() - && iter::zip(obl_sig_tys.iter(), impl_sig_tys.iter()) - .all(|(obl, imp)| self.types_may_unify(obl, imp)) + + ty::FnPtr(lhs_sig_tys, lhs_hdr) => match rhs.kind() { + ty::FnPtr(rhs_sig_tys, rhs_hdr) => { + let lhs_sig_tys = lhs_sig_tys.skip_binder().inputs_and_output; + let rhs_sig_tys = rhs_sig_tys.skip_binder().inputs_and_output; + + lhs_hdr == rhs_hdr + && lhs_sig_tys.len() == rhs_sig_tys.len() + && iter::zip(lhs_sig_tys.iter(), rhs_sig_tys.iter()) + .all(|(lhs, rhs)| self.types_may_unify(lhs, rhs)) } _ => false, }, - // Impls cannot contain these types as these cannot be named directly. - ty::FnDef(..) | ty::Closure(..) | ty::CoroutineClosure(..) | ty::Coroutine(..) => false, - - // Placeholder types don't unify with anything on their own - ty::Placeholder(..) | ty::Bound(..) => false, + ty::Bound(..) => true, - // Depending on the value of `treat_obligation_params`, we either - // treat generic parameters like placeholders or like inference variables. - ty::Param(_) => match self.treat_obligation_params { - TreatParams::ForLookup => false, - TreatParams::AsCandidateKey => true, + ty::FnDef(lhs_def_id, lhs_args) => match rhs.kind() { + ty::FnDef(rhs_def_id, rhs_args) => { + lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args) + } + _ => false, }, - ty::Infer(ty::IntVar(_)) => impl_ty.is_integral(), - - ty::Infer(ty::FloatVar(_)) => impl_ty.is_floating_point(), + ty::Closure(lhs_def_id, lhs_args) => match rhs.kind() { + ty::Closure(rhs_def_id, rhs_args) => { + lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args) + } + _ => false, + }, - ty::Infer(_) => true, + ty::CoroutineClosure(lhs_def_id, lhs_args) => match rhs.kind() { + ty::CoroutineClosure(rhs_def_id, rhs_args) => { + lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args) + } + _ => false, + }, - // As we're walking the whole type, it may encounter projections - // inside of binders and what not, so we're just going to assume that - // projections can unify with other stuff. - // - // Looking forward to lazy normalization this is the safer strategy anyways. - ty::Alias(..) => true, + ty::Coroutine(lhs_def_id, lhs_args) => match rhs.kind() { + ty::Coroutine(rhs_def_id, rhs_args) => { + lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args) + } + _ => false, + }, - ty::Error(_) => true, + ty::CoroutineWitness(lhs_def_id, lhs_args) => match rhs.kind() { + ty::CoroutineWitness(rhs_def_id, rhs_args) => { + lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args) + } + _ => false, + }, - ty::CoroutineWitness(..) => { - panic!("unexpected obligation type: {:?}", obligation_ty) + ty::Pat(lhs_ty, _) => { + // FIXME(pattern_types): take pattern into account + matches!(rhs.kind(), ty::Pat(rhs_ty, _) if self.types_may_unify(lhs_ty, rhs_ty)) } + + ty::Error(..) => true, } } - pub fn consts_may_unify(self, obligation_ct: I::Const, impl_ct: I::Const) -> bool { - let impl_val = match impl_ct.kind() { + pub fn consts_may_unify(self, lhs: I::Const, rhs: I::Const) -> bool { + match rhs.kind() { + ty::ConstKind::Param(_) => { + if INSTANTIATE_RHS_WITH_INFER { + return true; + } + } + ty::ConstKind::Expr(_) - | ty::ConstKind::Param(_) | ty::ConstKind::Unevaluated(_) - | ty::ConstKind::Error(_) => { + | ty::ConstKind::Error(_) + | ty::ConstKind::Infer(_) + | ty::ConstKind::Bound(..) => { return true; } - ty::ConstKind::Value(_, impl_val) => impl_val, - ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) | ty::ConstKind::Placeholder(_) => { - panic!("unexpected impl arg: {:?}", impl_ct) - } + + ty::ConstKind::Value(..) | ty::ConstKind::Placeholder(_) => {} }; - match obligation_ct.kind() { - ty::ConstKind::Param(_) => match self.treat_obligation_params { - TreatParams::ForLookup => false, - TreatParams::AsCandidateKey => true, + match lhs.kind() { + ty::ConstKind::Value(_, lhs_val) => match rhs.kind() { + ty::ConstKind::Value(_, rhs_val) => lhs_val == rhs_val, + _ => false, }, + ty::ConstKind::Param(lhs) => { + INSTANTIATE_LHS_WITH_INFER + || match rhs.kind() { + ty::ConstKind::Param(rhs) => lhs == rhs, + _ => false, + } + } + // Placeholder consts don't unify with anything on their own - ty::ConstKind::Placeholder(_) => false, + ty::ConstKind::Placeholder(lhs) => { + matches!(rhs.kind(), ty::ConstKind::Placeholder(rhs) if lhs == rhs) + } // As we don't necessarily eagerly evaluate constants, // they might unify with any value. ty::ConstKind::Expr(_) | ty::ConstKind::Unevaluated(_) | ty::ConstKind::Error(_) => { true } - ty::ConstKind::Value(_, obl_val) => obl_val == impl_val, - ty::ConstKind::Infer(_) => true, + ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) => true, + } + } - ty::ConstKind::Bound(..) => { - panic!("unexpected obl const: {:?}", obligation_ct) - } + fn var_and_ty_may_unify(self, var: ty::InferTy, ty: I::Ty) -> bool { + if !ty.is_known_rigid() { + return true; + } + + match var { + ty::IntVar(_) => ty.is_integral(), + ty::FloatVar(_) => ty.is_floating_point(), + _ => true, } } } diff --git a/src/librustdoc/html/render/write_shared.rs b/src/librustdoc/html/render/write_shared.rs index 60dc142b9ff91..b0af8d8c23ee5 100644 --- a/src/librustdoc/html/render/write_shared.rs +++ b/src/librustdoc/html/render/write_shared.rs @@ -29,7 +29,7 @@ use itertools::Itertools; use regex::Regex; use rustc_data_structures::flock; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; -use rustc_middle::ty::fast_reject::{DeepRejectCtxt, TreatParams}; +use rustc_middle::ty::fast_reject::DeepRejectCtxt; use rustc_middle::ty::TyCtxt; use rustc_span::def_id::DefId; use rustc_span::Symbol; @@ -832,7 +832,7 @@ impl<'cx, 'cache> DocVisitor for TypeImplCollector<'cx, 'cache> { // Be aware of `tests/rustdoc/type-alias/deeply-nested-112515.rs` which might regress. let Some(impl_did) = impl_item_id.as_def_id() else { continue }; let for_ty = self.cx.tcx().type_of(impl_did).skip_binder(); - let reject_cx = DeepRejectCtxt::new(self.cx.tcx(), TreatParams::AsCandidateKey); + let reject_cx = DeepRejectCtxt::relate_infer_infer(self.cx.tcx()); if !reject_cx.types_may_unify(aliased_ty, for_ty) { continue; }