Skip to content

Commit

Permalink
Auto merge of #126614 - compiler-errors:uplift-next-trait-solver, r=lcnr
Browse files Browse the repository at this point in the history
Uplift next trait solver to `rustc_next_trait_solver`

🎉

There's so many FIXMEs! Sorry! Ideally this merges with the FIXMEs and we track and squash them over the near future.

Also, this still doesn't build on anything other than rustc. I still need to fix `feature = "nightly"` in `rustc_type_ir`, and remove and fix all the nightly feature usage in the new trait solver (notably: let-chains).

Also, sorry `@lcnr` I know you asked for me to separate the commit where we `mv rustc_trait_selection/solve/... rustc_next_trait_solver/solve/...`, but I had already done all the work by that point. Luckily, `git` understands the file moves so it should still be relatively reviewable.

If this is still very difficult to review, then I can do some rebasing magic to try to separate this out. Please let me know!

r? lcnr
  • Loading branch information
bors committed Jun 18, 2024
2 parents dd104ef + 6609501 commit 8fcd4dd
Show file tree
Hide file tree
Showing 51 changed files with 2,824 additions and 1,823 deletions.
9 changes: 9 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4520,7 +4520,16 @@ dependencies = [
name = "rustc_next_trait_solver"
version = "0.0.0"
dependencies = [
"bitflags 2.5.0",
"derivative",
"rustc_ast_ir",
"rustc_data_structures",
"rustc_index",
"rustc_macros",
"rustc_serialize",
"rustc_type_ir",
"rustc_type_ir_macros",
"tracing",
]

[[package]]
Expand Down
151 changes: 6 additions & 145 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use BoundRegionConversionTime::*;
pub use RegionVariableOrigin::*;
pub use SubregionOrigin::*;

use crate::infer::relate::{Relate, RelateResult};
use crate::infer::relate::RelateResult;
use crate::traits::{self, ObligationCause, ObligationInspector, PredicateObligation, TraitEngine};
use error_reporting::TypeErrCtxt;
use free_regions::RegionRelations;
Expand Down Expand Up @@ -45,7 +45,7 @@ use rustc_middle::ty::{ConstVid, EffectVid, FloatVid, IntVid, TyVid};
use rustc_middle::ty::{GenericArg, GenericArgKind, GenericArgs, GenericArgsRef};
use rustc_middle::{bug, span_bug};
use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use rustc_span::Span;
use snapshot::undo_log::InferCtxtUndoLogs;
use std::cell::{Cell, RefCell};
use std::fmt;
Expand Down Expand Up @@ -335,149 +335,6 @@ pub struct InferCtxt<'tcx> {
pub obligation_inspector: Cell<Option<ObligationInspector<'tcx>>>,
}

impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> {
type Interner = TyCtxt<'tcx>;

fn interner(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn universe_of_ty(&self, vid: TyVid) -> Option<ty::UniverseIndex> {
// FIXME(BoxyUwU): this is kind of jank and means that printing unresolved
// ty infers will give you the universe of the var it resolved to not the universe
// it actually had. It also means that if you have a `?0.1` and infer it to `u8` then
// try to print out `?0.1` it will just print `?0`.
match self.probe_ty_var(vid) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn universe_of_lt(&self, lt: ty::RegionVid) -> Option<ty::UniverseIndex> {
match self.inner.borrow_mut().unwrap_region_constraints().probe_value(lt) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn universe_of_ct(&self, ct: ConstVid) -> Option<ty::UniverseIndex> {
// Same issue as with `universe_of_ty`
match self.probe_const_var(ct) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn root_ty_var(&self, var: TyVid) -> TyVid {
self.root_var(var)
}

fn root_const_var(&self, var: ConstVid) -> ConstVid {
self.root_const_var(var)
}

fn opportunistic_resolve_ty_var(&self, vid: TyVid) -> Ty<'tcx> {
match self.probe_ty_var(vid) {
Ok(ty) => ty,
Err(_) => Ty::new_var(self.tcx, self.root_var(vid)),
}
}

fn opportunistic_resolve_int_var(&self, vid: IntVid) -> Ty<'tcx> {
self.opportunistic_resolve_int_var(vid)
}

fn opportunistic_resolve_float_var(&self, vid: FloatVid) -> Ty<'tcx> {
self.opportunistic_resolve_float_var(vid)
}

fn opportunistic_resolve_ct_var(&self, vid: ConstVid) -> ty::Const<'tcx> {
match self.probe_const_var(vid) {
Ok(ct) => ct,
Err(_) => ty::Const::new_var(self.tcx, self.root_const_var(vid)),
}
}

fn opportunistic_resolve_effect_var(&self, vid: EffectVid) -> ty::Const<'tcx> {
match self.probe_effect_var(vid) {
Some(ct) => ct,
None => {
ty::Const::new_infer(self.tcx, InferConst::EffectVar(self.root_effect_var(vid)))
}
}
}

fn opportunistic_resolve_lt_var(&self, vid: ty::RegionVid) -> ty::Region<'tcx> {
self.inner.borrow_mut().unwrap_region_constraints().opportunistic_resolve_var(self.tcx, vid)
}

fn defining_opaque_types(&self) -> &'tcx ty::List<LocalDefId> {
self.defining_opaque_types
}

fn next_ty_infer(&self) -> Ty<'tcx> {
self.next_ty_var(DUMMY_SP)
}

fn next_const_infer(&self) -> ty::Const<'tcx> {
self.next_const_var(DUMMY_SP)
}

fn fresh_args_for_item(&self, def_id: DefId) -> ty::GenericArgsRef<'tcx> {
self.fresh_args_for_item(DUMMY_SP, def_id)
}

fn instantiate_binder_with_infer<T: TypeFoldable<Self::Interner> + Copy>(
&self,
value: ty::Binder<'tcx, T>,
) -> T {
self.instantiate_binder_with_fresh_vars(
DUMMY_SP,
BoundRegionConversionTime::HigherRankedType,
value,
)
}

fn enter_forall<T: TypeFoldable<TyCtxt<'tcx>> + Copy, U>(
&self,
value: ty::Binder<'tcx, T>,
f: impl FnOnce(T) -> U,
) -> U {
self.enter_forall(value, f)
}

fn relate<T: Relate<TyCtxt<'tcx>>>(
&self,
param_env: ty::ParamEnv<'tcx>,
lhs: T,
variance: ty::Variance,
rhs: T,
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
self.at(&ObligationCause::dummy(), param_env).relate_no_trace(lhs, variance, rhs)
}

fn eq_structurally_relating_aliases<T: Relate<TyCtxt<'tcx>>>(
&self,
param_env: ty::ParamEnv<'tcx>,
lhs: T,
rhs: T,
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
self.at(&ObligationCause::dummy(), param_env)
.eq_structurally_relating_aliases_no_trace(lhs, rhs)
}

fn resolve_vars_if_possible<T>(&self, value: T) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
self.resolve_vars_if_possible(value)
}

fn probe<T>(&self, probe: impl FnOnce() -> T) -> T {
self.probe(|_| probe())
}
}

/// See the `error_reporting` module for more details.
#[derive(Clone, Copy, Debug, PartialEq, Eq, TypeFoldable, TypeVisitable)]
pub enum ValuePairs<'tcx> {
Expand Down Expand Up @@ -831,6 +688,10 @@ impl<'tcx> InferCtxt<'tcx> {
self.tcx.dcx()
}

pub fn defining_opaque_types(&self) -> &'tcx ty::List<LocalDefId> {
self.defining_opaque_types
}

pub fn next_trait_solver(&self) -> bool {
self.next_trait_solver
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ macro_rules! arena_types {
rustc_middle::ty::EarlyBinder<'tcx, rustc_middle::ty::Ty<'tcx>>
>,
[] external_constraints: rustc_middle::traits::solve::ExternalConstraintsData<rustc_middle::ty::TyCtxt<'tcx>>,
[] predefined_opaques_in_body: rustc_middle::traits::solve::PredefinedOpaquesData<'tcx>,
[] predefined_opaques_in_body: rustc_middle::traits::solve::PredefinedOpaquesData<rustc_middle::ty::TyCtxt<'tcx>>,
[decode] doc_link_resolutions: rustc_hir::def::DocLinkResMap,
[] stripped_cfg_items: rustc_ast::expand::StrippedCfgItem,
[] mod_child: rustc_middle::metadata::ModChild,
Expand Down
14 changes: 4 additions & 10 deletions compiler/rustc_middle/src/traits/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use rustc_type_ir as ir;
pub use rustc_type_ir::solve::*;

use crate::ty::{
self, FallibleTypeFolder, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor,
self, FallibleTypeFolder, TyCtxt, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor,
};

mod cache;

pub use cache::{CacheData, EvaluationCache};
pub use cache::EvaluationCache;

pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>;
pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>;
Expand All @@ -19,17 +19,11 @@ pub type CandidateSource<'tcx> = ir::solve::CandidateSource<TyCtxt<'tcx>>;
pub type CanonicalInput<'tcx, P = ty::Predicate<'tcx>> = ir::solve::CanonicalInput<TyCtxt<'tcx>, P>;
pub type CanonicalResponse<'tcx> = ir::solve::CanonicalResponse<TyCtxt<'tcx>>;

/// Additional constraints returned on success.
#[derive(Debug, PartialEq, Eq, Clone, Hash, HashStable, Default)]
pub struct PredefinedOpaquesData<'tcx> {
pub opaque_types: Vec<(ty::OpaqueTypeKey<'tcx>, Ty<'tcx>)>,
}

#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, HashStable)]
pub struct PredefinedOpaques<'tcx>(pub(crate) Interned<'tcx, PredefinedOpaquesData<'tcx>>);
pub struct PredefinedOpaques<'tcx>(pub(crate) Interned<'tcx, PredefinedOpaquesData<TyCtxt<'tcx>>>);

impl<'tcx> std::ops::Deref for PredefinedOpaques<'tcx> {
type Target = PredefinedOpaquesData<'tcx>;
type Target = PredefinedOpaquesData<TyCtxt<'tcx>>;

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
28 changes: 11 additions & 17 deletions compiler/rustc_middle/src/traits/solve/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock;
use rustc_query_system::cache::WithDepNode;
use rustc_query_system::dep_graph::DepNodeIndex;
use rustc_session::Limit;
use rustc_type_ir::solve::CacheData;

/// The trait solver cache used by `-Znext-solver`.
///
/// FIXME(@lcnr): link to some official documentation of how
Expand All @@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> {
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct CacheData<'tcx> {
pub result: QueryResult<'tcx>,
pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>,
pub additional_depth: usize,
pub encountered_overflow: bool,
}

impl<'tcx> EvaluationCache<'tcx> {
impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> {
/// Insert a final result into the global cache.
pub fn insert(
fn insert(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
Expand All @@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> {
if cfg!(debug_assertions) {
drop(map);
let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
let actual = self.get(tcx, key, [], Limit(additional_depth));
let actual = self.get(tcx, key, [], additional_depth);
if !actual.as_ref().is_some_and(|actual| expected == *actual) {
bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
}
Expand All @@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> {
/// and handling root goals of coinductive cycles.
///
/// If this returns `Some` the cache result can be used.
pub fn get(
fn get(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
available_depth: Limit,
) -> Option<CacheData<'tcx>> {
available_depth: usize,
) -> Option<CacheData<TyCtxt<'tcx>>> {
let map = self.map.borrow();
let entry = map.get(&key)?;

Expand All @@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> {
}

if let Some(ref success) = entry.success {
if available_depth.value_within_limit(success.additional_depth) {
if Limit(available_depth).value_within_limit(success.additional_depth) {
let QueryData { result, proof_tree } = success.data.get(tcx);
return Some(CacheData {
result,
Expand All @@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> {
}
}

entry.with_overflow.get(&available_depth.0).map(|e| {
entry.with_overflow.get(&available_depth).map(|e| {
let QueryData { result, proof_tree } = e.get(tcx);
CacheData {
result,
proof_tree,
additional_depth: available_depth.0,
additional_depth: available_depth,
encountered_overflow: true,
}
})
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_middle/src/ty/adt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,22 @@ impl<'tcx> rustc_type_ir::inherent::AdtDef<TyCtxt<'tcx>> for AdtDef<'tcx> {
self.did()
}

fn is_struct(self) -> bool {
self.is_struct()
}

fn struct_tail_ty(self, interner: TyCtxt<'tcx>) -> Option<ty::EarlyBinder<'tcx, Ty<'tcx>>> {
Some(interner.type_of(self.non_enum_variant().tail_opt()?.did))
}

fn is_phantom_data(self) -> bool {
self.is_phantom_data()
}

fn all_field_tys(
self,
tcx: TyCtxt<'tcx>,
) -> ty::EarlyBinder<'tcx, impl Iterator<Item = Ty<'tcx>>> {
) -> ty::EarlyBinder<'tcx, impl IntoIterator<Item = Ty<'tcx>>> {
ty::EarlyBinder::bind(
self.all_fields().map(move |field| tcx.type_of(field.did).skip_binder()),
)
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ mod valtree;

pub use int::*;
pub use kind::*;
use rustc_span::Span;
use rustc_span::DUMMY_SP;
use rustc_span::{ErrorGuaranteed, Span};
pub use valtree::*;

pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
Expand Down Expand Up @@ -176,6 +176,10 @@ impl<'tcx> rustc_type_ir::inherent::Const<TyCtxt<'tcx>> for Const<'tcx> {
fn new_expr(interner: TyCtxt<'tcx>, expr: ty::Expr<'tcx>) -> Self {
Const::new_expr(interner, expr)
}

fn new_error(interner: TyCtxt<'tcx>, guar: ErrorGuaranteed) -> Self {
Const::new_error(interner, guar)
}
}

impl<'tcx> Const<'tcx> {
Expand Down
Loading

0 comments on commit 8fcd4dd

Please sign in to comment.