Skip to content

Commit

Permalink
Rollup merge of rust-lang#121059 - compiler-errors:extension, r=david…
Browse files Browse the repository at this point in the history
…twco,Nilstrieb

Add and use a simple extension trait derive macro in the compiler

Adds `#[extension]` to `rustc_macros` for implementing an extension trait. This expands an impl (with an optional visibility) into two parallel trait + impl definitions.

before:
```rust
pub trait Extension {
  fn a();
}
impl Extension for () {
  fn a() {}
}
```

to:
```rust
#[extension]
pub impl Extension for () {
  fn a() {}
}
```

Opted to just implement it by hand because I couldn't figure if there was a "canonical" choice of extension trait macro in the ecosystem. It's really lightweight anyways, and can always be changed.

I'm interested in adding this because I'd like to later split up the large `TypeErrCtxtExt` traits into several different files. This should make it one step easier.
  • Loading branch information
Nadrieril authored Feb 17, 2024
2 parents c5ffa54 + f624d55 commit 9ccabc8
Show file tree
Hide file tree
Showing 31 changed files with 291 additions and 977 deletions.
13 changes: 3 additions & 10 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ use rustc_hir::def::{DefKind, LifetimeRes, Namespace, PartialRes, PerNS, Res};
use rustc_hir::def_id::{LocalDefId, LocalDefIdMap, CRATE_DEF_ID, LOCAL_CRATE};
use rustc_hir::{ConstArg, GenericArg, ItemLocalMap, ParamName, TraitCandidate};
use rustc_index::{Idx, IndexSlice, IndexVec};
use rustc_macros::extension;
use rustc_middle::span_bug;
use rustc_middle::ty::{ResolverAstLowering, TyCtxt};
use rustc_session::parse::{add_feature_diagnostics, feature_err};
Expand Down Expand Up @@ -190,16 +191,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
}
}

trait ResolverAstLoweringExt {
fn legacy_const_generic_args(&self, expr: &Expr) -> Option<Vec<usize>>;
fn get_partial_res(&self, id: NodeId) -> Option<PartialRes>;
fn get_import_res(&self, id: NodeId) -> PerNS<Option<Res<NodeId>>>;
fn get_label_res(&self, id: NodeId) -> Option<NodeId>;
fn get_lifetime_res(&self, id: NodeId) -> Option<LifetimeRes>;
fn take_extra_lifetime_params(&mut self, id: NodeId) -> Vec<(Ident, NodeId, LifetimeRes)>;
}

impl ResolverAstLoweringExt for ResolverAstLowering {
#[extension(trait ResolverAstLoweringExt)]
impl ResolverAstLowering {
fn legacy_const_generic_args(&self, expr: &Expr) -> Option<Vec<usize>> {
if let ExprKind::Path(None, path) = &expr.kind {
// Don't perform legacy const generics rewriting if the path already
Expand Down
15 changes: 3 additions & 12 deletions compiler/rustc_borrowck/src/facts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::location::{LocationIndex, LocationTable};
use crate::BorrowIndex;
use polonius_engine::AllFacts as PoloniusFacts;
use polonius_engine::Atom;
use rustc_macros::extension;
use rustc_middle::mir::Local;
use rustc_middle::ty::{RegionVid, TyCtxt};
use rustc_mir_dataflow::move_paths::MovePathIndex;
Expand All @@ -24,20 +25,10 @@ impl polonius_engine::FactTypes for RustcFacts {

pub type AllFacts = PoloniusFacts<RustcFacts>;

pub(crate) trait AllFactsExt {
#[extension(pub(crate) trait AllFactsExt)]
impl AllFacts {
/// Returns `true` if there is a need to gather `AllFacts` given the
/// current `-Z` flags.
fn enabled(tcx: TyCtxt<'_>) -> bool;

fn write_to_dir(
&self,
dir: impl AsRef<Path>,
location_table: &LocationTable,
) -> Result<(), Box<dyn Error>>;
}

impl AllFactsExt for AllFacts {
/// Return
fn enabled(tcx: TyCtxt<'_>) -> bool {
tcx.sess.opts.unstable_opts.nll_facts
|| tcx.sess.opts.unstable_opts.polonius.is_legacy_enabled()
Expand Down
14 changes: 3 additions & 11 deletions compiler/rustc_borrowck/src/place_ext.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
use crate::borrow_set::LocalsStateAtExit;
use rustc_hir as hir;
use rustc_macros::extension;
use rustc_middle::mir::ProjectionElem;
use rustc_middle::mir::{Body, Mutability, Place};
use rustc_middle::ty::{self, TyCtxt};

/// Extension methods for the `Place` type.
pub trait PlaceExt<'tcx> {
#[extension(pub trait PlaceExt<'tcx>)]
impl<'tcx> Place<'tcx> {
/// Returns `true` if we can safely ignore borrows of this place.
/// This is true whenever there is no action that the user can do
/// to the place `self` that would invalidate the borrow. This is true
/// for borrows of raw pointer dereferents as well as shared references.
fn ignore_borrow(
&self,
tcx: TyCtxt<'tcx>,
body: &Body<'tcx>,
locals_state_at_exit: &LocalsStateAtExit,
) -> bool;
}

impl<'tcx> PlaceExt<'tcx> for Place<'tcx> {
fn ignore_borrow(
&self,
tcx: TyCtxt<'tcx>,
Expand Down
12 changes: 3 additions & 9 deletions compiler/rustc_borrowck/src/region_infer/opaque_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use rustc_hir::OpaqueTyOrigin;
use rustc_infer::infer::InferCtxt;
use rustc_infer::infer::TyCtxtInferExt as _;
use rustc_infer::traits::{Obligation, ObligationCause};
use rustc_macros::extension;
use rustc_middle::traits::DefiningAnchor;
use rustc_middle::ty::visit::TypeVisitableExt;
use rustc_middle::ty::{self, OpaqueHiddenType, OpaqueTypeKey, Ty, TyCtxt, TypeFoldable};
Expand Down Expand Up @@ -225,15 +226,8 @@ impl<'tcx> RegionInferenceContext<'tcx> {
}
}

pub trait InferCtxtExt<'tcx> {
fn infer_opaque_definition_from_instantiation(
&self,
opaque_type_key: OpaqueTypeKey<'tcx>,
instantiated_ty: OpaqueHiddenType<'tcx>,
) -> Ty<'tcx>;
}

impl<'tcx> InferCtxtExt<'tcx> for InferCtxt<'tcx> {
#[extension(pub trait InferCtxtExt<'tcx>)]
impl<'tcx> InferCtxt<'tcx> {
/// Given the fully resolved, instantiated type for an opaque
/// type, i.e., the value of an inference variable like C1 or C2
/// (*), computes the "definition type" for an opaque type
Expand Down
24 changes: 3 additions & 21 deletions compiler/rustc_borrowck/src/universal_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use rustc_hir::lang_items::LangItem;
use rustc_hir::BodyOwnerKind;
use rustc_index::IndexVec;
use rustc_infer::infer::NllRegionVariableOrigin;
use rustc_macros::extension;
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::{self, InlineConstArgs, InlineConstArgsParts, RegionVid, Ty, TyCtxt};
Expand Down Expand Up @@ -793,27 +794,8 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
}
}

trait InferCtxtExt<'tcx> {
fn replace_free_regions_with_nll_infer_vars<T>(
&self,
origin: NllRegionVariableOrigin,
value: T,
) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>;

fn replace_bound_regions_with_nll_infer_vars<T>(
&self,
origin: NllRegionVariableOrigin,
all_outlive_scope: LocalDefId,
value: ty::Binder<'tcx, T>,
indices: &mut UniversalRegionIndices<'tcx>,
) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>;
}

impl<'cx, 'tcx> InferCtxtExt<'tcx> for BorrowckInferCtxt<'cx, 'tcx> {
#[extension(trait InferCtxtExt<'tcx>)]
impl<'cx, 'tcx> BorrowckInferCtxt<'cx, 'tcx> {
#[instrument(skip(self), level = "debug")]
fn replace_free_regions_with_nll_infer_vars<T>(
&self,
Expand Down
14 changes: 3 additions & 11 deletions compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustc_hir::def::{DefKind, Res};
use rustc_hir::def_id::LocalDefId;
use rustc_hir::intravisit::{self, Visitor};
use rustc_hir::{GenericArg, GenericParam, GenericParamKind, HirIdMap, LifetimeName, Node};
use rustc_macros::extension;
use rustc_middle::bug;
use rustc_middle::hir::nested_filter;
use rustc_middle::middle::resolve_bound_vars::*;
Expand All @@ -27,17 +28,8 @@ use std::fmt;

use crate::errors;

trait RegionExt {
fn early(param: &GenericParam<'_>) -> (LocalDefId, ResolvedArg);

fn late(index: u32, param: &GenericParam<'_>) -> (LocalDefId, ResolvedArg);

fn id(&self) -> Option<DefId>;

fn shifted(self, amount: u32) -> ResolvedArg;
}

impl RegionExt for ResolvedArg {
#[extension(trait RegionExt)]
impl ResolvedArg {
fn early(param: &GenericParam<'_>) -> (LocalDefId, ResolvedArg) {
debug!("ResolvedArg::early: def_id={:?}", param.def_id);
(param.def_id, ResolvedArg::EarlyBound(param.def_id.to_def_id()))
Expand Down
26 changes: 6 additions & 20 deletions compiler/rustc_infer/src/infer/canonical/instantiate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,23 @@ use rustc_middle::ty::{self, TyCtxt};

/// FIXME(-Znext-solver): This or public because it is shared with the
/// new trait solver implementation. We should deduplicate canonicalization.
pub trait CanonicalExt<'tcx, V> {
#[extension(pub trait CanonicalExt<'tcx, V>)]
impl<'tcx, V> Canonical<'tcx, V> {
/// Instantiate the wrapped value, replacing each canonical value
/// with the value given in `var_values`.
fn instantiate(&self, tcx: TyCtxt<'tcx>, var_values: &CanonicalVarValues<'tcx>) -> V
where
V: TypeFoldable<TyCtxt<'tcx>>;
V: TypeFoldable<TyCtxt<'tcx>>,
{
self.instantiate_projected(tcx, var_values, |value| value.clone())
}

/// Allows one to apply a instantiation to some subset of
/// `self.value`. Invoke `projection_fn` with `self.value` to get
/// a value V that is expressed in terms of the same canonical
/// variables bound in `self` (usually this extracts from subset
/// of `self`). Apply the instantiation `var_values` to this value
/// V, replacing each of the canonical variables.
fn instantiate_projected<T>(
&self,
tcx: TyCtxt<'tcx>,
var_values: &CanonicalVarValues<'tcx>,
projection_fn: impl FnOnce(&V) -> T,
) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>;
}

impl<'tcx, V> CanonicalExt<'tcx, V> for Canonical<'tcx, V> {
fn instantiate(&self, tcx: TyCtxt<'tcx>, var_values: &CanonicalVarValues<'tcx>) -> V
where
V: TypeFoldable<TyCtxt<'tcx>>,
{
self.instantiate_projected(tcx, var_values, |value| value.clone())
}

fn instantiate_projected<T>(
&self,
tcx: TyCtxt<'tcx>,
Expand Down
15 changes: 2 additions & 13 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2786,19 +2786,8 @@ pub enum FailureCode {
Error0644,
}

pub trait ObligationCauseExt<'tcx> {
fn as_failure_code(&self, terr: TypeError<'tcx>) -> FailureCode;

fn as_failure_code_diag(
&self,
terr: TypeError<'tcx>,
span: Span,
subdiags: Vec<TypeErrorAdditionalDiags>,
) -> ObligationCauseFailureCode;
fn as_requirement_str(&self) -> &'static str;
}

impl<'tcx> ObligationCauseExt<'tcx> for ObligationCause<'tcx> {
#[extension(pub trait ObligationCauseExt<'tcx>)]
impl<'tcx> ObligationCause<'tcx> {
fn as_failure_code(&self, terr: TypeError<'tcx>) -> FailureCode {
use self::FailureCode::*;
use crate::traits::ObligationCauseCode::*;
Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,11 +626,8 @@ pub struct InferCtxtBuilder<'tcx> {
next_trait_solver: bool,
}

pub trait TyCtxtInferExt<'tcx> {
fn infer_ctxt(self) -> InferCtxtBuilder<'tcx>;
}

impl<'tcx> TyCtxtInferExt<'tcx> for TyCtxt<'tcx> {
#[extension(pub trait TyCtxtInferExt<'tcx>)]
impl<'tcx> TyCtxt<'tcx> {
fn infer_ctxt(self) -> InferCtxtBuilder<'tcx> {
InferCtxtBuilder {
tcx: self,
Expand Down
15 changes: 3 additions & 12 deletions compiler/rustc_infer/src/traits/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,8 @@ pub trait TraitEngine<'tcx>: 'tcx {
) -> Vec<PredicateObligation<'tcx>>;
}

pub trait TraitEngineExt<'tcx> {
fn register_predicate_obligations(
&mut self,
infcx: &InferCtxt<'tcx>,
obligations: impl IntoIterator<Item = PredicateObligation<'tcx>>,
);

#[must_use]
fn select_all_or_error(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>>;
}

impl<'tcx, T: ?Sized + TraitEngine<'tcx>> TraitEngineExt<'tcx> for T {
#[extension(pub trait TraitEngineExt<'tcx>)]
impl<'tcx, T: ?Sized + TraitEngine<'tcx>> T {
fn register_predicate_obligations(
&mut self,
infcx: &InferCtxt<'tcx>,
Expand All @@ -74,6 +64,7 @@ impl<'tcx, T: ?Sized + TraitEngine<'tcx>> TraitEngineExt<'tcx> for T {
}
}

#[must_use]
fn select_all_or_error(&mut self, infcx: &InferCtxt<'tcx>) -> Vec<FulfillmentError<'tcx>> {
let errors = self.select_where_possible(infcx);
if !errors.is_empty() {
Expand Down
Loading

0 comments on commit 9ccabc8

Please sign in to comment.