From a21ba348964e6d72a3369daa4bc22af507eb7778 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 29 Jun 2024 23:01:24 -0400 Subject: [PATCH 1/2] add TyCtxt::as_lang_item, use in new solver --- compiler/rustc_hir/src/lang_items.rs | 13 ++- .../rustc_middle/src/middle/lang_items.rs | 4 + compiler/rustc_middle/src/ty/context.rs | 95 +++++++++++-------- .../src/solve/assembly/mod.rs | 87 +++++++++-------- compiler/rustc_type_ir/src/interner.rs | 2 + 5 files changed, 126 insertions(+), 75 deletions(-) diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 3c44acb16577a..30c0e40206aaf 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -11,6 +11,7 @@ use crate::def_id::DefId; use crate::{MethodKind, Target}; use rustc_ast as ast; +use rustc_data_structures::fx::FxIndexMap; use rustc_data_structures::stable_hasher::{HashStable, StableHasher}; use rustc_macros::{Decodable, Encodable, HashStable_Generic}; use rustc_span::symbol::{kw, sym, Symbol}; @@ -23,6 +24,7 @@ pub struct LanguageItems { /// Mappings from lang items to their possibly found [`DefId`]s. /// The index corresponds to the order in [`LangItem`]. items: [Option; std::mem::variant_count::()], + reverse_items: FxIndexMap, /// Lang items that were not found during collection. pub missing: Vec, } @@ -30,7 +32,11 @@ pub struct LanguageItems { impl LanguageItems { /// Construct an empty collection of lang items and no missing ones. pub fn new() -> Self { - Self { items: [None; std::mem::variant_count::()], missing: Vec::new() } + Self { + items: [None; std::mem::variant_count::()], + reverse_items: FxIndexMap::default(), + missing: Vec::new(), + } } pub fn get(&self, item: LangItem) -> Option { @@ -39,6 +45,11 @@ impl LanguageItems { pub fn set(&mut self, item: LangItem, def_id: DefId) { self.items[item as usize] = Some(def_id); + self.reverse_items.insert(def_id, item); + } + + pub fn from_def_id(&self, def_id: DefId) -> Option { + self.reverse_items.get(&def_id).copied() } pub fn iter(&self) -> impl Iterator + '_ { diff --git a/compiler/rustc_middle/src/middle/lang_items.rs b/compiler/rustc_middle/src/middle/lang_items.rs index e76d7af6e4aba..a0c9af436e207 100644 --- a/compiler/rustc_middle/src/middle/lang_items.rs +++ b/compiler/rustc_middle/src/middle/lang_items.rs @@ -27,6 +27,10 @@ impl<'tcx> TyCtxt<'tcx> { self.lang_items().get(lang_item) == Some(def_id) } + pub fn as_lang_item(self, def_id: DefId) -> Option { + self.lang_items().from_def_id(def_id) + } + /// Given a [`DefId`] of one of the [`Fn`], [`FnMut`] or [`FnOnce`] traits, /// returns a corresponding [`ty::ClosureKind`]. /// For any other [`DefId`] return `None`. diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 9225ae6300fc6..155a76713cab6 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -366,6 +366,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> { self.is_lang_item(def_id, trait_lang_item_to_lang_item(lang_item)) } + fn as_lang_item(self, def_id: DefId) -> Option { + lang_item_to_trait_lang_item(self.lang_items().from_def_id(def_id)?) + } + fn associated_type_def_ids(self, def_id: DefId) -> impl IntoIterator { self.associated_items(def_id) .in_definition_order() @@ -584,46 +588,63 @@ impl<'tcx> Interner for TyCtxt<'tcx> { } } -fn trait_lang_item_to_lang_item(lang_item: TraitSolverLangItem) -> LangItem { - match lang_item { - TraitSolverLangItem::AsyncDestruct => LangItem::AsyncDestruct, - TraitSolverLangItem::AsyncFnKindHelper => LangItem::AsyncFnKindHelper, - TraitSolverLangItem::AsyncFnKindUpvars => LangItem::AsyncFnKindUpvars, - TraitSolverLangItem::AsyncFnOnceOutput => LangItem::AsyncFnOnceOutput, - TraitSolverLangItem::AsyncIterator => LangItem::AsyncIterator, - TraitSolverLangItem::CallOnceFuture => LangItem::CallOnceFuture, - TraitSolverLangItem::CallRefFuture => LangItem::CallRefFuture, - TraitSolverLangItem::Clone => LangItem::Clone, - TraitSolverLangItem::Copy => LangItem::Copy, - TraitSolverLangItem::Coroutine => LangItem::Coroutine, - TraitSolverLangItem::CoroutineReturn => LangItem::CoroutineReturn, - TraitSolverLangItem::CoroutineYield => LangItem::CoroutineYield, - TraitSolverLangItem::Destruct => LangItem::Destruct, - TraitSolverLangItem::DiscriminantKind => LangItem::DiscriminantKind, - TraitSolverLangItem::DynMetadata => LangItem::DynMetadata, - TraitSolverLangItem::EffectsMaybe => LangItem::EffectsMaybe, - TraitSolverLangItem::EffectsIntersection => LangItem::EffectsIntersection, - TraitSolverLangItem::EffectsIntersectionOutput => LangItem::EffectsIntersectionOutput, - TraitSolverLangItem::EffectsNoRuntime => LangItem::EffectsNoRuntime, - TraitSolverLangItem::EffectsRuntime => LangItem::EffectsRuntime, - TraitSolverLangItem::FnPtrTrait => LangItem::FnPtrTrait, - TraitSolverLangItem::FusedIterator => LangItem::FusedIterator, - TraitSolverLangItem::Future => LangItem::Future, - TraitSolverLangItem::FutureOutput => LangItem::FutureOutput, - TraitSolverLangItem::Iterator => LangItem::Iterator, - TraitSolverLangItem::Metadata => LangItem::Metadata, - TraitSolverLangItem::Option => LangItem::Option, - TraitSolverLangItem::PointeeTrait => LangItem::PointeeTrait, - TraitSolverLangItem::PointerLike => LangItem::PointerLike, - TraitSolverLangItem::Poll => LangItem::Poll, - TraitSolverLangItem::Sized => LangItem::Sized, - TraitSolverLangItem::TransmuteTrait => LangItem::TransmuteTrait, - TraitSolverLangItem::Tuple => LangItem::Tuple, - TraitSolverLangItem::Unpin => LangItem::Unpin, - TraitSolverLangItem::Unsize => LangItem::Unsize, +macro_rules! bidirectional_lang_item_map { + ($($name:ident),+ $(,)?) => { + fn trait_lang_item_to_lang_item(lang_item: TraitSolverLangItem) -> LangItem { + match lang_item { + $(TraitSolverLangItem::$name => LangItem::$name,)+ + } + } + + fn lang_item_to_trait_lang_item(lang_item: LangItem) -> Option { + Some(match lang_item { + $(LangItem::$name => TraitSolverLangItem::$name,)+ + _ => return None, + }) + } } } +bidirectional_lang_item_map! { +// tidy-alphabetical-start + AsyncDestruct, + AsyncFnKindHelper, + AsyncFnKindUpvars, + AsyncFnOnceOutput, + AsyncIterator, + CallOnceFuture, + CallRefFuture, + Clone, + Copy, + Coroutine, + CoroutineReturn, + CoroutineYield, + Destruct, + DiscriminantKind, + DynMetadata, + EffectsMaybe, + EffectsIntersection, + EffectsIntersectionOutput, + EffectsNoRuntime, + EffectsRuntime, + FnPtrTrait, + FusedIterator, + Future, + FutureOutput, + Iterator, + Metadata, + Option, + PointeeTrait, + PointerLike, + Poll, + Sized, + TransmuteTrait, + Tuple, + Unpin, + Unsize, +// tidy-alphabetical-end +} + impl<'tcx> rustc_type_ir::inherent::DefId> for DefId { fn as_local(self) -> Option { self.as_local() diff --git a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs index 6ee684605ac60..20a0da16c7521 100644 --- a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs @@ -387,48 +387,61 @@ where G::consider_auto_trait_candidate(self, goal) } else if cx.trait_is_alias(trait_def_id) { G::consider_trait_alias_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Sized) { - G::consider_builtin_sized_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Copy) - || cx.is_lang_item(trait_def_id, TraitSolverLangItem::Clone) - { - G::consider_builtin_copy_clone_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::PointerLike) { - G::consider_builtin_pointer_like_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::FnPtrTrait) { - G::consider_builtin_fn_ptr_trait_candidate(self, goal) } else if let Some(kind) = self.cx().fn_trait_kind_from_def_id(trait_def_id) { G::consider_builtin_fn_trait_candidates(self, goal, kind) } else if let Some(kind) = self.cx().async_fn_trait_kind_from_def_id(trait_def_id) { G::consider_builtin_async_fn_trait_candidates(self, goal, kind) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncFnKindHelper) { - G::consider_builtin_async_fn_kind_helper_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Tuple) { - G::consider_builtin_tuple_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::PointeeTrait) { - G::consider_builtin_pointee_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Future) { - G::consider_builtin_future_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Iterator) { - G::consider_builtin_iterator_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::FusedIterator) { - G::consider_builtin_fused_iterator_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncIterator) { - G::consider_builtin_async_iterator_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Coroutine) { - G::consider_builtin_coroutine_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::DiscriminantKind) { - G::consider_builtin_discriminant_kind_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::AsyncDestruct) { - G::consider_builtin_async_destruct_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::Destruct) { - G::consider_builtin_destruct_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::TransmuteTrait) { - G::consider_builtin_transmute_candidate(self, goal) - } else if cx.is_lang_item(trait_def_id, TraitSolverLangItem::EffectsIntersection) { - G::consider_builtin_effects_intersection_candidate(self, goal) } else { - Err(NoSolution) + match cx.as_lang_item(trait_def_id) { + Some(TraitSolverLangItem::Sized) => G::consider_builtin_sized_candidate(self, goal), + Some(TraitSolverLangItem::Copy | TraitSolverLangItem::Clone) => { + G::consider_builtin_copy_clone_candidate(self, goal) + } + Some(TraitSolverLangItem::PointerLike) => { + G::consider_builtin_pointer_like_candidate(self, goal) + } + Some(TraitSolverLangItem::FnPtrTrait) => { + G::consider_builtin_fn_ptr_trait_candidate(self, goal) + } + Some(TraitSolverLangItem::AsyncFnKindHelper) => { + G::consider_builtin_async_fn_kind_helper_candidate(self, goal) + } + Some(TraitSolverLangItem::Tuple) => G::consider_builtin_tuple_candidate(self, goal), + Some(TraitSolverLangItem::PointeeTrait) => { + G::consider_builtin_pointee_candidate(self, goal) + } + Some(TraitSolverLangItem::Future) => { + G::consider_builtin_future_candidate(self, goal) + } + Some(TraitSolverLangItem::Iterator) => { + G::consider_builtin_iterator_candidate(self, goal) + } + Some(TraitSolverLangItem::FusedIterator) => { + G::consider_builtin_fused_iterator_candidate(self, goal) + } + Some(TraitSolverLangItem::AsyncIterator) => { + G::consider_builtin_async_iterator_candidate(self, goal) + } + Some(TraitSolverLangItem::Coroutine) => { + G::consider_builtin_coroutine_candidate(self, goal) + } + Some(TraitSolverLangItem::DiscriminantKind) => { + G::consider_builtin_discriminant_kind_candidate(self, goal) + } + Some(TraitSolverLangItem::AsyncDestruct) => { + G::consider_builtin_async_destruct_candidate(self, goal) + } + Some(TraitSolverLangItem::Destruct) => { + G::consider_builtin_destruct_candidate(self, goal) + } + Some(TraitSolverLangItem::TransmuteTrait) => { + G::consider_builtin_transmute_candidate(self, goal) + } + Some(TraitSolverLangItem::EffectsIntersection) => { + G::consider_builtin_effects_intersection_candidate(self, goal) + } + _ => Err(NoSolution), + } }; candidates.extend(result); diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 6665158c7cd34..84c7e6e8f2003 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -220,6 +220,8 @@ pub trait Interner: fn is_lang_item(self, def_id: Self::DefId, lang_item: TraitSolverLangItem) -> bool; + fn as_lang_item(self, def_id: Self::DefId) -> Option; + fn associated_type_def_ids(self, def_id: Self::DefId) -> impl IntoIterator; // FIXME: move `fast_reject` into `rustc_type_ir`. From 5a837515f28eed106dd031b0243d6476eba58dd9 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 29 Jun 2024 23:05:35 -0400 Subject: [PATCH 2/2] Make fn traits into first-class TraitSolverLangItems to avoid needing fn_trait_kind_from_def_id --- compiler/rustc_middle/src/ty/context.rs | 16 +++++----- .../src/solve/assembly/mod.rs | 30 ++++++++++++++++--- compiler/rustc_type_ir/src/interner.rs | 4 --- compiler/rustc_type_ir/src/lang_items.rs | 6 ++++ 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index 155a76713cab6..7c3d2df72faa2 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -537,14 +537,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> { self.trait_def(trait_def_id).implement_via_object } - fn fn_trait_kind_from_def_id(self, trait_def_id: DefId) -> Option { - self.fn_trait_kind_from_def_id(trait_def_id) - } - - fn async_fn_trait_kind_from_def_id(self, trait_def_id: DefId) -> Option { - self.async_fn_trait_kind_from_def_id(trait_def_id) - } - fn supertrait_def_ids(self, trait_def_id: DefId) -> impl IntoIterator { self.supertrait_def_ids(trait_def_id) } @@ -608,8 +600,11 @@ macro_rules! bidirectional_lang_item_map { bidirectional_lang_item_map! { // tidy-alphabetical-start AsyncDestruct, + AsyncFn, AsyncFnKindHelper, AsyncFnKindUpvars, + AsyncFnMut, + AsyncFnOnce, AsyncFnOnceOutput, AsyncIterator, CallOnceFuture, @@ -622,11 +617,14 @@ bidirectional_lang_item_map! { Destruct, DiscriminantKind, DynMetadata, - EffectsMaybe, EffectsIntersection, EffectsIntersectionOutput, + EffectsMaybe, EffectsNoRuntime, EffectsRuntime, + Fn, + FnMut, + FnOnce, FnPtrTrait, FusedIterator, Future, diff --git a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs index 20a0da16c7521..38a4f7dfe25cb 100644 --- a/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs +++ b/compiler/rustc_next_trait_solver/src/solve/assembly/mod.rs @@ -387,16 +387,38 @@ where G::consider_auto_trait_candidate(self, goal) } else if cx.trait_is_alias(trait_def_id) { G::consider_trait_alias_candidate(self, goal) - } else if let Some(kind) = self.cx().fn_trait_kind_from_def_id(trait_def_id) { - G::consider_builtin_fn_trait_candidates(self, goal, kind) - } else if let Some(kind) = self.cx().async_fn_trait_kind_from_def_id(trait_def_id) { - G::consider_builtin_async_fn_trait_candidates(self, goal, kind) } else { match cx.as_lang_item(trait_def_id) { Some(TraitSolverLangItem::Sized) => G::consider_builtin_sized_candidate(self, goal), Some(TraitSolverLangItem::Copy | TraitSolverLangItem::Clone) => { G::consider_builtin_copy_clone_candidate(self, goal) } + Some(TraitSolverLangItem::Fn) => { + G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::Fn) + } + Some(TraitSolverLangItem::FnMut) => { + G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::FnMut) + } + Some(TraitSolverLangItem::FnOnce) => { + G::consider_builtin_fn_trait_candidates(self, goal, ty::ClosureKind::FnOnce) + } + Some(TraitSolverLangItem::AsyncFn) => { + G::consider_builtin_async_fn_trait_candidates(self, goal, ty::ClosureKind::Fn) + } + Some(TraitSolverLangItem::AsyncFnMut) => { + G::consider_builtin_async_fn_trait_candidates( + self, + goal, + ty::ClosureKind::FnMut, + ) + } + Some(TraitSolverLangItem::AsyncFnOnce) => { + G::consider_builtin_async_fn_trait_candidates( + self, + goal, + ty::ClosureKind::FnOnce, + ) + } Some(TraitSolverLangItem::PointerLike) => { G::consider_builtin_pointer_like_candidate(self, goal) } diff --git a/compiler/rustc_type_ir/src/interner.rs b/compiler/rustc_type_ir/src/interner.rs index 84c7e6e8f2003..f76a2f55c482e 100644 --- a/compiler/rustc_type_ir/src/interner.rs +++ b/compiler/rustc_type_ir/src/interner.rs @@ -254,10 +254,6 @@ pub trait Interner: fn trait_may_be_implemented_via_object(self, trait_def_id: Self::DefId) -> bool; - fn fn_trait_kind_from_def_id(self, trait_def_id: Self::DefId) -> Option; - - fn async_fn_trait_kind_from_def_id(self, trait_def_id: Self::DefId) -> Option; - fn supertrait_def_ids(self, trait_def_id: Self::DefId) -> impl IntoIterator; diff --git a/compiler/rustc_type_ir/src/lang_items.rs b/compiler/rustc_type_ir/src/lang_items.rs index cf00c37caa287..265a411882735 100644 --- a/compiler/rustc_type_ir/src/lang_items.rs +++ b/compiler/rustc_type_ir/src/lang_items.rs @@ -3,8 +3,11 @@ pub enum TraitSolverLangItem { // tidy-alphabetical-start AsyncDestruct, + AsyncFn, AsyncFnKindHelper, AsyncFnKindUpvars, + AsyncFnMut, + AsyncFnOnce, AsyncFnOnceOutput, AsyncIterator, CallOnceFuture, @@ -22,6 +25,9 @@ pub enum TraitSolverLangItem { EffectsMaybe, EffectsNoRuntime, EffectsRuntime, + Fn, + FnMut, + FnOnce, FnPtrTrait, FusedIterator, Future,