From ad094cdcebde8336d4f63d5323db5843e27def40 Mon Sep 17 00:00:00 2001 From: Santiago Pastorino Date: Wed, 16 Nov 2022 15:58:48 -0300 Subject: [PATCH 1/3] Use ObligationCtxt intead of dyn TraitEngine --- compiler/rustc_hir_typeck/src/coercion.rs | 11 ++++----- .../src/traits/error_reporting/mod.rs | 12 ++++------ .../rustc_trait_selection/src/traits/mod.rs | 19 ++++++++------- .../src/traits/specialize/mod.rs | 12 +++++----- .../src/traits/structural_match.rs | 23 ++++--------------- 5 files changed, 30 insertions(+), 47 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/coercion.rs b/compiler/rustc_hir_typeck/src/coercion.rs index 43c7127b0d4c5..3f0d0a76027f4 100644 --- a/compiler/rustc_hir_typeck/src/coercion.rs +++ b/compiler/rustc_hir_typeck/src/coercion.rs @@ -46,7 +46,7 @@ use rustc_hir::Expr; use rustc_hir_analysis::astconv::AstConv; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; use rustc_infer::infer::{Coercion, InferOk, InferResult}; -use rustc_infer::traits::{Obligation, TraitEngine, TraitEngineExt}; +use rustc_infer::traits::Obligation; use rustc_middle::lint::in_external_macro; use rustc_middle::ty::adjustment::{ Adjust, Adjustment, AllowTwoPhase, AutoBorrow, AutoBorrowMutability, PointerCast, @@ -62,8 +62,7 @@ use rustc_span::{self, BytePos, DesugaringKind, Span}; use rustc_target::spec::abi::Abi; use rustc_trait_selection::infer::InferCtxtExt as _; use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _; -use rustc_trait_selection::traits::TraitEngineExt as _; -use rustc_trait_selection::traits::{self, ObligationCause, ObligationCauseCode}; +use rustc_trait_selection::traits::{self, ObligationCause, ObligationCauseCode, ObligationCtxt}; use smallvec::{smallvec, SmallVec}; use std::ops::Deref; @@ -1055,9 +1054,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let Ok(ok) = coerce.coerce(source, target) else { return false; }; - let mut fcx = >::new_in_snapshot(self.tcx); - fcx.register_predicate_obligations(self, ok.obligations); - fcx.select_where_possible(&self).is_empty() + let ocx = ObligationCtxt::new_in_snapshot(self); + ocx.register_obligations(ok.obligations); + ocx.select_where_possible().is_empty() }) } diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index ef3d300020a39..de31eb1aa5719 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -4,13 +4,12 @@ pub mod suggestions; use super::{ FulfillmentError, FulfillmentErrorCode, MismatchedProjectionTypes, Obligation, ObligationCause, - ObligationCauseCode, OutputTypeParameterMismatch, Overflow, PredicateObligation, - SelectionContext, SelectionError, TraitNotObjectSafe, + ObligationCauseCode, ObligationCtxt, OutputTypeParameterMismatch, Overflow, + PredicateObligation, SelectionContext, SelectionError, TraitNotObjectSafe, }; use crate::infer::error_reporting::{TyCategory, TypeAnnotationNeeded as ErrorCode}; use crate::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; use crate::infer::{self, InferCtxt, TyCtxtInferExt}; -use crate::traits::engine::TraitEngineExt as _; use crate::traits::query::evaluate_obligation::InferCtxtExt as _; use crate::traits::query::normalize::AtExt as _; use crate::traits::specialize::to_pretty_impl_header; @@ -30,7 +29,6 @@ use rustc_hir::Item; use rustc_hir::Node; use rustc_infer::infer::error_reporting::TypeErrCtxt; use rustc_infer::infer::TypeTrace; -use rustc_infer::traits::TraitEngine; use rustc_middle::traits::select::OverflowError; use rustc_middle::ty::abstract_const::NotConstEvaluatable; use rustc_middle::ty::error::ExpectedFound; @@ -354,9 +352,9 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> { param_env, ty.rebind(ty::TraitPredicate { trait_ref, constness, polarity }), ); - let mut fulfill_cx = >::new_in_snapshot(self.tcx); - fulfill_cx.register_predicate_obligation(self, obligation); - if fulfill_cx.select_all_or_error(self).is_empty() { + let ocx = ObligationCtxt::new_in_snapshot(self); + ocx.register_obligation(obligation); + if ocx.select_all_or_error().is_empty() { return Ok(( ty::ClosureKind::from_def_id(self.tcx, trait_def_id) .expect("expected to map DefId to ClosureKind"), diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index ff18aa1f9e909..8a42bf4113a45 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -31,7 +31,6 @@ use rustc_errors::ErrorGuaranteed; use rustc_hir as hir; use rustc_hir::def_id::DefId; use rustc_hir::lang_items::LangItem; -use rustc_infer::traits::TraitEngineExt as _; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::visit::TypeVisitable; use rustc_middle::ty::{ @@ -403,9 +402,9 @@ pub fn fully_solve_obligation<'tcx>( infcx: &InferCtxt<'tcx>, obligation: PredicateObligation<'tcx>, ) -> Vec> { - let mut engine = >::new(infcx.tcx); - engine.register_predicate_obligation(infcx, obligation); - engine.select_all_or_error(infcx) + let ocx = ObligationCtxt::new(infcx); + ocx.register_obligation(obligation); + ocx.select_all_or_error() } /// Process a set of obligations (and any nested obligations that come from them) @@ -414,9 +413,9 @@ pub fn fully_solve_obligations<'tcx>( infcx: &InferCtxt<'tcx>, obligations: impl IntoIterator>, ) -> Vec> { - let mut engine = >::new(infcx.tcx); - engine.register_predicate_obligations(infcx, obligations); - engine.select_all_or_error(infcx) + let ocx = ObligationCtxt::new(infcx); + ocx.register_obligations(obligations); + ocx.select_all_or_error() } /// Process a bound (and any nested obligations that come from it) to completion. @@ -429,9 +428,9 @@ pub fn fully_solve_bound<'tcx>( ty: Ty<'tcx>, bound: DefId, ) -> Vec> { - let mut engine = >::new(infcx.tcx); - engine.register_bound(infcx, param_env, ty, bound, cause); - engine.select_all_or_error(infcx) + let ocx = ObligationCtxt::new(infcx); + ocx.register_bound(cause, param_env, ty, bound); + ocx.select_all_or_error() } /// Normalizes the predicates and checks whether they hold in an empty environment. If this diff --git a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs index 7cc12eff20e8b..9a3c0707c7ce9 100644 --- a/compiler/rustc_trait_selection/src/traits/specialize/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/specialize/mod.rs @@ -10,14 +10,14 @@ //! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/traits/specialization.html pub mod specialization_graph; -use rustc_infer::traits::{TraitEngine, TraitEngineExt as _}; use specialization_graph::GraphExt; use crate::errors::NegativePositiveConflict; use crate::infer::{InferCtxt, InferOk, TyCtxtInferExt}; -use crate::traits::engine::TraitEngineExt as _; use crate::traits::select::IntercrateAmbiguityCause; -use crate::traits::{self, coherence, FutureCompatOverlapErrorKind, ObligationCause}; +use crate::traits::{ + self, coherence, FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt, +}; use rustc_data_structures::fx::FxIndexSet; use rustc_errors::{error_code, DelayDm, Diagnostic}; use rustc_hir::def_id::{DefId, LocalDefId}; @@ -204,12 +204,12 @@ fn fulfill_implication<'tcx>( // Needs to be `in_snapshot` because this function is used to rebase // substitutions, which may happen inside of a select within a probe. - let mut engine = >::new_in_snapshot(infcx.tcx); + let ocx = ObligationCtxt::new_in_snapshot(infcx); // attempt to prove all of the predicates for impl2 given those for impl1 // (which are packed up in penv) - engine.register_predicate_obligations(infcx, obligations.chain(more_obligations)); + ocx.register_obligations(obligations.chain(more_obligations)); - let errors = engine.select_all_or_error(infcx); + let errors = ocx.select_all_or_error(); if !errors.is_empty() { // no dice! debug!( diff --git a/compiler/rustc_trait_selection/src/traits/structural_match.rs b/compiler/rustc_trait_selection/src/traits/structural_match.rs index 932dbbb81e5cc..40dbe0b3ff063 100644 --- a/compiler/rustc_trait_selection/src/traits/structural_match.rs +++ b/compiler/rustc_trait_selection/src/traits/structural_match.rs @@ -1,6 +1,5 @@ use crate::infer::{InferCtxt, TyCtxtInferExt}; -use crate::traits::ObligationCause; -use crate::traits::{TraitEngine, TraitEngineExt}; +use crate::traits::{ObligationCause, ObligationCtxt}; use rustc_data_structures::fx::FxHashSet; use rustc_hir as hir; @@ -72,28 +71,16 @@ fn type_marked_structural<'tcx>( adt_ty: Ty<'tcx>, cause: ObligationCause<'tcx>, ) -> bool { - let mut fulfillment_cx = >::new(infcx.tcx); + let ocx = ObligationCtxt::new(infcx); // require `#[derive(PartialEq)]` let structural_peq_def_id = infcx.tcx.require_lang_item(LangItem::StructuralPeq, Some(cause.span)); - fulfillment_cx.register_bound( - infcx, - ty::ParamEnv::empty(), - adt_ty, - structural_peq_def_id, - cause.clone(), - ); + ocx.register_bound(cause.clone(), ty::ParamEnv::empty(), adt_ty, structural_peq_def_id); // for now, require `#[derive(Eq)]`. (Doing so is a hack to work around // the type `for<'a> fn(&'a ())` failing to implement `Eq` itself.) let structural_teq_def_id = infcx.tcx.require_lang_item(LangItem::StructuralTeq, Some(cause.span)); - fulfillment_cx.register_bound( - infcx, - ty::ParamEnv::empty(), - adt_ty, - structural_teq_def_id, - cause, - ); + ocx.register_bound(cause, ty::ParamEnv::empty(), adt_ty, structural_teq_def_id); // We deliberately skip *reporting* fulfillment errors (via // `report_fulfillment_errors`), for two reasons: @@ -104,7 +91,7 @@ fn type_marked_structural<'tcx>( // // 2. We are sometimes doing future-incompatibility lints for // now, so we do not want unconditional errors here. - fulfillment_cx.select_all_or_error(infcx).is_empty() + ocx.select_all_or_error().is_empty() } /// This implements the traversal over the structure of a given type to try to From 5b3a06a3c2584d303cb40637a50a4bc3f0d8cedf Mon Sep 17 00:00:00 2001 From: Santiago Pastorino Date: Thu, 17 Nov 2022 11:44:24 -0300 Subject: [PATCH 2/3] Call fully_solve_obligations instead of repeating code --- .../rustc_trait_selection/src/traits/mod.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/compiler/rustc_trait_selection/src/traits/mod.rs b/compiler/rustc_trait_selection/src/traits/mod.rs index 8a42bf4113a45..548ca1c1d7faa 100644 --- a/compiler/rustc_trait_selection/src/traits/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/mod.rs @@ -402,9 +402,7 @@ pub fn fully_solve_obligation<'tcx>( infcx: &InferCtxt<'tcx>, obligation: PredicateObligation<'tcx>, ) -> Vec> { - let ocx = ObligationCtxt::new(infcx); - ocx.register_obligation(obligation); - ocx.select_all_or_error() + fully_solve_obligations(infcx, [obligation]) } /// Process a set of obligations (and any nested obligations that come from them) @@ -428,9 +426,16 @@ pub fn fully_solve_bound<'tcx>( ty: Ty<'tcx>, bound: DefId, ) -> Vec> { - let ocx = ObligationCtxt::new(infcx); - ocx.register_bound(cause, param_env, ty, bound); - ocx.select_all_or_error() + let tcx = infcx.tcx; + let trait_ref = ty::TraitRef { def_id: bound, substs: tcx.mk_substs_trait(ty, []) }; + let obligation = Obligation { + cause, + recursion_depth: 0, + param_env, + predicate: ty::Binder::dummy(trait_ref).without_const().to_predicate(tcx), + }; + + fully_solve_obligation(infcx, obligation) } /// Normalizes the predicates and checks whether they hold in an empty environment. If this From 859b147d4f6683b15f309bf3b997efdefca7767d Mon Sep 17 00:00:00 2001 From: Santiago Pastorino Date: Wed, 16 Nov 2022 19:40:55 -0300 Subject: [PATCH 3/3] Pass ObligationCtxt from enter_canonical_trait_query and use ObligationCtxt API --- .../src/implied_outlives_bounds.rs | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_traits/src/implied_outlives_bounds.rs b/compiler/rustc_traits/src/implied_outlives_bounds.rs index 2d1a386992617..3ab353c963802 100644 --- a/compiler/rustc_traits/src/implied_outlives_bounds.rs +++ b/compiler/rustc_traits/src/implied_outlives_bounds.rs @@ -5,16 +5,15 @@ use rustc_hir as hir; use rustc_infer::infer::canonical::{self, Canonical}; use rustc_infer::infer::outlives::components::{push_outlives_components, Component}; -use rustc_infer::infer::{InferCtxt, TyCtxtInferExt}; +use rustc_infer::infer::TyCtxtInferExt; use rustc_infer::traits::query::OutlivesBound; -use rustc_infer::traits::TraitEngineExt as _; use rustc_middle::ty::query::Providers; use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitable}; use rustc_span::source_map::DUMMY_SP; use rustc_trait_selection::infer::InferCtxtBuilderExt; use rustc_trait_selection::traits::query::{CanonicalTyGoal, Fallible, NoSolution}; use rustc_trait_selection::traits::wf; -use rustc_trait_selection::traits::{TraitEngine, TraitEngineExt}; +use rustc_trait_selection::traits::ObligationCtxt; use smallvec::{smallvec, SmallVec}; pub(crate) fn provide(p: &mut Providers) { @@ -30,16 +29,16 @@ fn implied_outlives_bounds<'tcx>( > { tcx.infer_ctxt().enter_canonical_trait_query(&goal, |ocx, key| { let (param_env, ty) = key.into_parts(); - compute_implied_outlives_bounds(&ocx.infcx, param_env, ty) + compute_implied_outlives_bounds(ocx, param_env, ty) }) } fn compute_implied_outlives_bounds<'tcx>( - infcx: &InferCtxt<'tcx>, + ocx: &ObligationCtxt<'_, 'tcx>, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>, ) -> Fallible>> { - let tcx = infcx.tcx; + let tcx = ocx.infcx.tcx; // Sometimes when we ask what it takes for T: WF, we get back that // U: WF is required; in that case, we push U onto this stack and @@ -52,8 +51,6 @@ fn compute_implied_outlives_bounds<'tcx>( let mut outlives_bounds: Vec, ty::Region<'tcx>>> = vec![]; - let mut fulfill_cx = >::new(tcx); - while let Some(arg) = wf_args.pop() { if !checked_wf_args.insert(arg) { continue; @@ -70,15 +67,15 @@ fn compute_implied_outlives_bounds<'tcx>( // FIXME(@lcnr): It's not really "always fine", having fewer implied // bounds can be backward incompatible, e.g. #101951 was caused by // us not dealing with inference vars in `TypeOutlives` predicates. - let obligations = wf::obligations(infcx, param_env, hir::CRATE_HIR_ID, 0, arg, DUMMY_SP) - .unwrap_or_default(); + let obligations = + wf::obligations(ocx.infcx, param_env, hir::CRATE_HIR_ID, 0, arg, DUMMY_SP) + .unwrap_or_default(); // While these predicates should all be implied by other parts of // the program, they are still relevant as they may constrain // inference variables, which is necessary to add the correct // implied bounds in some cases, mostly when dealing with projections. - fulfill_cx.register_predicate_obligations( - infcx, + ocx.register_obligations( obligations.iter().filter(|o| o.predicate.has_non_region_infer()).cloned(), ); @@ -116,9 +113,9 @@ fn compute_implied_outlives_bounds<'tcx>( })); } - // Ensure that those obligations that we had to solve - // get solved *here*. - match fulfill_cx.select_all_or_error(infcx).as_slice() { + // This call to `select_all_or_error` is necessary to constrain inference variables, which we + // use further down when computing the implied bounds. + match ocx.select_all_or_error().as_slice() { [] => (), _ => return Err(NoSolution), } @@ -130,7 +127,7 @@ fn compute_implied_outlives_bounds<'tcx>( .flat_map(|ty::OutlivesPredicate(a, r_b)| match a.unpack() { ty::GenericArgKind::Lifetime(r_a) => vec![OutlivesBound::RegionSubRegion(r_b, r_a)], ty::GenericArgKind::Type(ty_a) => { - let ty_a = infcx.resolve_vars_if_possible(ty_a); + let ty_a = ocx.infcx.resolve_vars_if_possible(ty_a); let mut components = smallvec![]; push_outlives_components(tcx, ty_a, &mut components); implied_bounds_from_components(r_b, components)