Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use obligation ctxt instead of dyn TraitEngine #104509

Merged
merged 3 commits into from
Nov 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use rustc_hir::Expr;
use rustc_hir_analysis::astconv::AstConv;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_infer::infer::{Coercion, InferOk, InferResult};
use rustc_infer::traits::{Obligation, TraitEngine, TraitEngineExt};
use rustc_infer::traits::Obligation;
use rustc_middle::lint::in_external_macro;
use rustc_middle::ty::adjustment::{
Adjust, Adjustment, AllowTwoPhase, AutoBorrow, AutoBorrowMutability, PointerCast,
Expand All @@ -62,8 +62,7 @@ use rustc_span::{self, BytePos, DesugaringKind, Span};
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::infer::InferCtxtExt as _;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::TraitEngineExt as _;
use rustc_trait_selection::traits::{self, ObligationCause, ObligationCauseCode};
use rustc_trait_selection::traits::{self, ObligationCause, ObligationCauseCode, ObligationCtxt};

use smallvec::{smallvec, SmallVec};
use std::ops::Deref;
Expand Down Expand Up @@ -1055,9 +1054,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let Ok(ok) = coerce.coerce(source, target) else {
return false;
};
let mut fcx = <dyn TraitEngine<'tcx>>::new_in_snapshot(self.tcx);
fcx.register_predicate_obligations(self, ok.obligations);
fcx.select_where_possible(&self).is_empty()
let ocx = ObligationCtxt::new_in_snapshot(self);
ocx.register_obligations(ok.obligations);
ocx.select_where_possible().is_empty()
})
}

Expand Down
12 changes: 5 additions & 7 deletions compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ pub mod suggestions;

use super::{
FulfillmentError, FulfillmentErrorCode, MismatchedProjectionTypes, Obligation, ObligationCause,
ObligationCauseCode, OutputTypeParameterMismatch, Overflow, PredicateObligation,
SelectionContext, SelectionError, TraitNotObjectSafe,
ObligationCauseCode, ObligationCtxt, OutputTypeParameterMismatch, Overflow,
PredicateObligation, SelectionContext, SelectionError, TraitNotObjectSafe,
};
use crate::infer::error_reporting::{TyCategory, TypeAnnotationNeeded as ErrorCode};
use crate::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use crate::infer::{self, InferCtxt, TyCtxtInferExt};
use crate::traits::engine::TraitEngineExt as _;
use crate::traits::query::evaluate_obligation::InferCtxtExt as _;
use crate::traits::query::normalize::AtExt as _;
use crate::traits::specialize::to_pretty_impl_header;
Expand All @@ -30,7 +29,6 @@ use rustc_hir::Item;
use rustc_hir::Node;
use rustc_infer::infer::error_reporting::TypeErrCtxt;
use rustc_infer::infer::TypeTrace;
use rustc_infer::traits::TraitEngine;
use rustc_middle::traits::select::OverflowError;
use rustc_middle::ty::abstract_const::NotConstEvaluatable;
use rustc_middle::ty::error::ExpectedFound;
Expand Down Expand Up @@ -354,9 +352,9 @@ impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
param_env,
ty.rebind(ty::TraitPredicate { trait_ref, constness, polarity }),
);
let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new_in_snapshot(self.tcx);
fulfill_cx.register_predicate_obligation(self, obligation);
if fulfill_cx.select_all_or_error(self).is_empty() {
let ocx = ObligationCtxt::new_in_snapshot(self);
ocx.register_obligation(obligation);
if ocx.select_all_or_error().is_empty() {
return Ok((
ty::ClosureKind::from_def_id(self.tcx, trait_def_id)
.expect("expected to map DefId to ClosureKind"),
Expand Down
24 changes: 14 additions & 10 deletions compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use rustc_errors::ErrorGuaranteed;
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_infer::traits::TraitEngineExt as _;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::visit::TypeVisitable;
use rustc_middle::ty::{
Expand Down Expand Up @@ -403,9 +402,7 @@ pub fn fully_solve_obligation<'tcx>(
infcx: &InferCtxt<'tcx>,
obligation: PredicateObligation<'tcx>,
) -> Vec<FulfillmentError<'tcx>> {
let mut engine = <dyn TraitEngine<'tcx>>::new(infcx.tcx);
engine.register_predicate_obligation(infcx, obligation);
engine.select_all_or_error(infcx)
fully_solve_obligations(infcx, [obligation])
}

/// Process a set of obligations (and any nested obligations that come from them)
Expand All @@ -414,9 +411,9 @@ pub fn fully_solve_obligations<'tcx>(
infcx: &InferCtxt<'tcx>,
obligations: impl IntoIterator<Item = PredicateObligation<'tcx>>,
) -> Vec<FulfillmentError<'tcx>> {
let mut engine = <dyn TraitEngine<'tcx>>::new(infcx.tcx);
engine.register_predicate_obligations(infcx, obligations);
engine.select_all_or_error(infcx)
let ocx = ObligationCtxt::new(infcx);
ocx.register_obligations(obligations);
ocx.select_all_or_error()
}

/// Process a bound (and any nested obligations that come from it) to completion.
Expand All @@ -429,9 +426,16 @@ pub fn fully_solve_bound<'tcx>(
ty: Ty<'tcx>,
bound: DefId,
) -> Vec<FulfillmentError<'tcx>> {
let mut engine = <dyn TraitEngine<'tcx>>::new(infcx.tcx);
engine.register_bound(infcx, param_env, ty, bound, cause);
engine.select_all_or_error(infcx)
let tcx = infcx.tcx;
let trait_ref = ty::TraitRef { def_id: bound, substs: tcx.mk_substs_trait(ty, []) };
let obligation = Obligation {
cause,
recursion_depth: 0,
param_env,
predicate: ty::Binder::dummy(trait_ref).without_const().to_predicate(tcx),
};

fully_solve_obligation(infcx, obligation)
}

/// Normalizes the predicates and checks whether they hold in an empty environment. If this
Expand Down
12 changes: 6 additions & 6 deletions compiler/rustc_trait_selection/src/traits/specialize/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
//! [rustc dev guide]: https://rustc-dev-guide.rust-lang.org/traits/specialization.html

pub mod specialization_graph;
use rustc_infer::traits::{TraitEngine, TraitEngineExt as _};
use specialization_graph::GraphExt;

use crate::errors::NegativePositiveConflict;
use crate::infer::{InferCtxt, InferOk, TyCtxtInferExt};
use crate::traits::engine::TraitEngineExt as _;
use crate::traits::select::IntercrateAmbiguityCause;
use crate::traits::{self, coherence, FutureCompatOverlapErrorKind, ObligationCause};
use crate::traits::{
self, coherence, FutureCompatOverlapErrorKind, ObligationCause, ObligationCtxt,
};
use rustc_data_structures::fx::FxIndexSet;
use rustc_errors::{error_code, DelayDm, Diagnostic};
use rustc_hir::def_id::{DefId, LocalDefId};
Expand Down Expand Up @@ -204,12 +204,12 @@ fn fulfill_implication<'tcx>(

// Needs to be `in_snapshot` because this function is used to rebase
// substitutions, which may happen inside of a select within a probe.
let mut engine = <dyn TraitEngine<'tcx>>::new_in_snapshot(infcx.tcx);
let ocx = ObligationCtxt::new_in_snapshot(infcx);
// attempt to prove all of the predicates for impl2 given those for impl1
// (which are packed up in penv)
engine.register_predicate_obligations(infcx, obligations.chain(more_obligations));
ocx.register_obligations(obligations.chain(more_obligations));

let errors = engine.select_all_or_error(infcx);
let errors = ocx.select_all_or_error();
if !errors.is_empty() {
// no dice!
debug!(
Expand Down
23 changes: 5 additions & 18 deletions compiler/rustc_trait_selection/src/traits/structural_match.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::infer::{InferCtxt, TyCtxtInferExt};
use crate::traits::ObligationCause;
use crate::traits::{TraitEngine, TraitEngineExt};
use crate::traits::{ObligationCause, ObligationCtxt};

use rustc_data_structures::fx::FxHashSet;
use rustc_hir as hir;
Expand Down Expand Up @@ -72,28 +71,16 @@ fn type_marked_structural<'tcx>(
adt_ty: Ty<'tcx>,
cause: ObligationCause<'tcx>,
) -> bool {
let mut fulfillment_cx = <dyn TraitEngine<'tcx>>::new(infcx.tcx);
let ocx = ObligationCtxt::new(infcx);
// require `#[derive(PartialEq)]`
let structural_peq_def_id =
infcx.tcx.require_lang_item(LangItem::StructuralPeq, Some(cause.span));
fulfillment_cx.register_bound(
infcx,
ty::ParamEnv::empty(),
adt_ty,
structural_peq_def_id,
cause.clone(),
);
ocx.register_bound(cause.clone(), ty::ParamEnv::empty(), adt_ty, structural_peq_def_id);
// for now, require `#[derive(Eq)]`. (Doing so is a hack to work around
// the type `for<'a> fn(&'a ())` failing to implement `Eq` itself.)
let structural_teq_def_id =
infcx.tcx.require_lang_item(LangItem::StructuralTeq, Some(cause.span));
fulfillment_cx.register_bound(
infcx,
ty::ParamEnv::empty(),
adt_ty,
structural_teq_def_id,
cause,
);
ocx.register_bound(cause, ty::ParamEnv::empty(), adt_ty, structural_teq_def_id);

// We deliberately skip *reporting* fulfillment errors (via
// `report_fulfillment_errors`), for two reasons:
Expand All @@ -104,7 +91,7 @@ fn type_marked_structural<'tcx>(
//
// 2. We are sometimes doing future-incompatibility lints for
// now, so we do not want unconditional errors here.
fulfillment_cx.select_all_or_error(infcx).is_empty()
ocx.select_all_or_error().is_empty()
}

/// This implements the traversal over the structure of a given type to try to
Expand Down
29 changes: 13 additions & 16 deletions compiler/rustc_traits/src/implied_outlives_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
use rustc_hir as hir;
use rustc_infer::infer::canonical::{self, Canonical};
use rustc_infer::infer::outlives::components::{push_outlives_components, Component};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::query::OutlivesBound;
use rustc_infer::traits::TraitEngineExt as _;
use rustc_middle::ty::query::Providers;
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitable};
use rustc_span::source_map::DUMMY_SP;
use rustc_trait_selection::infer::InferCtxtBuilderExt;
use rustc_trait_selection::traits::query::{CanonicalTyGoal, Fallible, NoSolution};
use rustc_trait_selection::traits::wf;
use rustc_trait_selection::traits::{TraitEngine, TraitEngineExt};
use rustc_trait_selection::traits::ObligationCtxt;
use smallvec::{smallvec, SmallVec};

pub(crate) fn provide(p: &mut Providers) {
Expand All @@ -30,16 +29,16 @@ fn implied_outlives_bounds<'tcx>(
> {
tcx.infer_ctxt().enter_canonical_trait_query(&goal, |ocx, key| {
let (param_env, ty) = key.into_parts();
compute_implied_outlives_bounds(&ocx.infcx, param_env, ty)
compute_implied_outlives_bounds(ocx, param_env, ty)
})
}

fn compute_implied_outlives_bounds<'tcx>(
infcx: &InferCtxt<'tcx>,
ocx: &ObligationCtxt<'_, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
ty: Ty<'tcx>,
) -> Fallible<Vec<OutlivesBound<'tcx>>> {
let tcx = infcx.tcx;
let tcx = ocx.infcx.tcx;

// Sometimes when we ask what it takes for T: WF, we get back that
// U: WF is required; in that case, we push U onto this stack and
Expand All @@ -52,8 +51,6 @@ fn compute_implied_outlives_bounds<'tcx>(
let mut outlives_bounds: Vec<ty::OutlivesPredicate<ty::GenericArg<'tcx>, ty::Region<'tcx>>> =
vec![];

let mut fulfill_cx = <dyn TraitEngine<'tcx>>::new(tcx);

while let Some(arg) = wf_args.pop() {
if !checked_wf_args.insert(arg) {
continue;
Expand All @@ -70,15 +67,15 @@ fn compute_implied_outlives_bounds<'tcx>(
// FIXME(@lcnr): It's not really "always fine", having fewer implied
// bounds can be backward incompatible, e.g. #101951 was caused by
// us not dealing with inference vars in `TypeOutlives` predicates.
let obligations = wf::obligations(infcx, param_env, hir::CRATE_HIR_ID, 0, arg, DUMMY_SP)
.unwrap_or_default();
let obligations =
wf::obligations(ocx.infcx, param_env, hir::CRATE_HIR_ID, 0, arg, DUMMY_SP)
.unwrap_or_default();

// While these predicates should all be implied by other parts of
// the program, they are still relevant as they may constrain
// inference variables, which is necessary to add the correct
// implied bounds in some cases, mostly when dealing with projections.
fulfill_cx.register_predicate_obligations(
infcx,
ocx.register_obligations(
obligations.iter().filter(|o| o.predicate.has_non_region_infer()).cloned(),
);

Expand Down Expand Up @@ -116,9 +113,9 @@ fn compute_implied_outlives_bounds<'tcx>(
}));
}

// Ensure that those obligations that we had to solve
// get solved *here*.
match fulfill_cx.select_all_or_error(infcx).as_slice() {
// This call to `select_all_or_error` is necessary to constrain inference variables, which we
// use further down when computing the implied bounds.
match ocx.select_all_or_error().as_slice() {
[] => (),
_ => return Err(NoSolution),
}
Comment on lines +118 to 121
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does calling select_all_or_error here do to the select_* calls that happen in make_query_response, since we're inside of a canonical query here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Maybe it does nothing -- just curious, since we're passing down the ObligationCtxt from the canonical query instead of creating a trait engine inline, it might change behavior -- but maybe not!)

Copy link
Contributor

@lcnr lcnr Nov 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is necessary to constrain inference variables which we use further down when computing the implied bounds. Can you add a comment mentioning that here?

Expand All @@ -130,7 +127,7 @@ fn compute_implied_outlives_bounds<'tcx>(
.flat_map(|ty::OutlivesPredicate(a, r_b)| match a.unpack() {
ty::GenericArgKind::Lifetime(r_a) => vec![OutlivesBound::RegionSubRegion(r_b, r_a)],
ty::GenericArgKind::Type(ty_a) => {
let ty_a = infcx.resolve_vars_if_possible(ty_a);
let ty_a = ocx.infcx.resolve_vars_if_possible(ty_a);
let mut components = smallvec![];
push_outlives_components(tcx, ty_a, &mut components);
implied_bounds_from_components(r_b, components)
Expand Down