From 004450506e1302239fa9ea09d50d9dc5a3261a32 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Thu, 21 Dec 2023 18:49:20 +0000 Subject: [PATCH] Split coroutine desugaring kind from source --- compiler/rustc_ast_lowering/src/expr.rs | 35 ++++++--- .../src/diagnostics/conflict_errors.rs | 25 +++--- .../src/diagnostics/region_errors.rs | 5 +- .../src/diagnostics/region_name.rs | 65 +++++++++------- .../src/debuginfo/type_names.rs | 36 ++++++--- .../src/transform/check_consts/check.rs | 8 +- .../src/transform/check_consts/ops.rs | 12 ++- compiler/rustc_hir/src/hir.rs | 78 +++++++++++-------- compiler/rustc_hir_typeck/src/callee.rs | 6 +- compiler/rustc_hir_typeck/src/check.rs | 7 +- compiler/rustc_hir_typeck/src/closure.rs | 12 ++- .../src/fn_ctxt/suggestions.rs | 9 ++- compiler/rustc_metadata/src/rmeta/table.rs | 20 ++--- compiler/rustc_middle/src/mir/mod.rs | 2 +- compiler/rustc_middle/src/mir/terminator.rs | 38 ++++++--- compiler/rustc_middle/src/ty/context.rs | 15 +++- compiler/rustc_middle/src/ty/util.rs | 22 ++++-- compiler/rustc_mir_transform/src/coroutine.rs | 44 +++++++---- .../rustc_smir/src/rustc_smir/convert/mod.rs | 23 ++++-- .../src/traits/error_reporting/suggestions.rs | 77 ++++++++++++------ .../error_reporting/type_err_ctxt_ext.rs | 47 ++++++++--- compiler/rustc_ty_utils/src/abi.rs | 12 +-- compiler/stable_mir/src/mir/body.rs | 55 ++++++++----- .../clippy_lints/src/async_yields_async.rs | 5 +- .../clippy_lints/src/await_holding_invalid.rs | 5 +- .../clippy_lints/src/manual_async_fn.rs | 4 +- .../src/needless_question_mark.rs | 4 +- .../clippy_lints/src/redundant_async_block.rs | 4 +- .../src/redundant_closure_call.rs | 4 +- .../clippy/clippy_lints/src/unused_async.rs | 8 +- 30 files changed, 448 insertions(+), 239 deletions(-) diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index 2d61f3bceec73..5fa04dda8be5e 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -670,7 +670,10 @@ impl<'hir> LoweringContext<'_, 'hir> { let params = arena_vec![self; param]; let body = self.lower_body(move |this| { - this.coroutine_kind = Some(hir::CoroutineKind::Async(async_coroutine_source)); + this.coroutine_kind = Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + async_coroutine_source, + )); let old_ctx = this.task_context; this.task_context = Some(task_context_hid); @@ -724,7 +727,10 @@ impl<'hir> LoweringContext<'_, 'hir> { }); let body = self.lower_body(move |this| { - this.coroutine_kind = Some(hir::CoroutineKind::Gen(coroutine_source)); + this.coroutine_kind = Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Gen, + coroutine_source, + )); let res = body(this); (&[], res) @@ -802,7 +808,10 @@ impl<'hir> LoweringContext<'_, 'hir> { let params = arena_vec![self; param]; let body = self.lower_body(move |this| { - this.coroutine_kind = Some(hir::CoroutineKind::AsyncGen(async_coroutine_source)); + this.coroutine_kind = Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::AsyncGen, + async_coroutine_source, + )); let old_ctx = this.task_context; this.task_context = Some(task_context_hid); @@ -888,9 +897,11 @@ impl<'hir> LoweringContext<'_, 'hir> { let full_span = expr.span.to(await_kw_span); let is_async_gen = match self.coroutine_kind { - Some(hir::CoroutineKind::Async(_)) => false, - Some(hir::CoroutineKind::AsyncGen(_)) => true, - Some(hir::CoroutineKind::Coroutine) | Some(hir::CoroutineKind::Gen(_)) | None => { + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => false, + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true, + Some(hir::CoroutineKind::Coroutine) + | Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) + | None => { return hir::ExprKind::Err(self.tcx.sess.emit_err(AwaitOnlyInAsyncFnAndBlocks { await_kw_span, item_span: self.current_item, @@ -1123,9 +1134,9 @@ impl<'hir> LoweringContext<'_, 'hir> { Some(movability) } Some( - hir::CoroutineKind::Gen(_) - | hir::CoroutineKind::Async(_) - | hir::CoroutineKind::AsyncGen(_), + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) + | hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) + | hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _), ) => { panic!("non-`async`/`gen` closure body turned `async`/`gen` during lowering"); } @@ -1638,9 +1649,9 @@ impl<'hir> LoweringContext<'_, 'hir> { fn lower_expr_yield(&mut self, span: Span, opt_expr: Option<&Expr>) -> hir::ExprKind<'hir> { let is_async_gen = match self.coroutine_kind { - Some(hir::CoroutineKind::Gen(_)) => false, - Some(hir::CoroutineKind::AsyncGen(_)) => true, - Some(hir::CoroutineKind::Async(_)) => { + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) => false, + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => true, + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => { return hir::ExprKind::Err( self.tcx.sess.emit_err(AsyncCoroutinesNotSupported { span }), ); diff --git a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs index db0f4559a6be0..e1f8ce2aff585 100644 --- a/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs @@ -1,5 +1,4 @@ use either::Either; -use hir::PatField; use rustc_data_structures::captures::Captures; use rustc_data_structures::fx::FxIndexSet; use rustc_errors::{ @@ -8,6 +7,7 @@ use rustc_errors::{ use rustc_hir as hir; use rustc_hir::def::{DefKind, Res}; use rustc_hir::intravisit::{walk_block, walk_expr, Visitor}; +use rustc_hir::{CoroutineDesugaring, PatField}; use rustc_hir::{CoroutineKind, CoroutineSource, LangItem}; use rustc_infer::traits::ObligationCause; use rustc_middle::hir::nested_filter::OnlyBodies; @@ -2516,27 +2516,29 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { }; let kind = match use_span.coroutine_kind() { Some(coroutine_kind) => match coroutine_kind { - CoroutineKind::Gen(kind) => match kind { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, kind) => match kind { CoroutineSource::Block => "gen block", CoroutineSource::Closure => "gen closure", CoroutineSource::Fn => { bug!("gen block/closure expected, but gen function found.") } }, - CoroutineKind::AsyncGen(kind) => match kind { + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, kind) => match kind { CoroutineSource::Block => "async gen block", CoroutineSource::Closure => "async gen closure", CoroutineSource::Fn => { bug!("gen block/closure expected, but gen function found.") } }, - CoroutineKind::Async(async_kind) => match async_kind { - CoroutineSource::Block => "async block", - CoroutineSource::Closure => "async closure", - CoroutineSource::Fn => { - bug!("async block/closure expected, but async function found.") + CoroutineKind::Desugared(CoroutineDesugaring::Async, async_kind) => { + match async_kind { + CoroutineSource::Block => "async block", + CoroutineSource::Closure => "async closure", + CoroutineSource::Fn => { + bug!("async block/closure expected, but async function found.") + } } - }, + } CoroutineKind::Coroutine => "coroutine", }, None => "closure", @@ -2566,7 +2568,10 @@ impl<'cx, 'tcx> MirBorrowckCtxt<'cx, 'tcx> { } ConstraintCategory::CallArgument(_) => { fr_name.highlight_region_name(&mut err); - if matches!(use_span.coroutine_kind(), Some(CoroutineKind::Async(_))) { + if matches!( + use_span.coroutine_kind(), + Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) + ) { err.note( "async blocks are not executed immediately and must either take a \ reference or ownership of outside variables they use", diff --git a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs index 759f5e910f70a..348c07f2470ab 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_errors.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_errors.rs @@ -1049,7 +1049,10 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> { .. }) => { let body = map.body(*body); - if !matches!(body.coroutine_kind, Some(hir::CoroutineKind::Async(..))) { + if !matches!( + body.coroutine_kind, + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) + ) { closure_span = Some(expr.span.shrink_to_lo()); } } diff --git a/compiler/rustc_borrowck/src/diagnostics/region_name.rs b/compiler/rustc_borrowck/src/diagnostics/region_name.rs index 8441dfaa7df28..78d84f468e026 100644 --- a/compiler/rustc_borrowck/src/diagnostics/region_name.rs +++ b/compiler/rustc_borrowck/src/diagnostics/region_name.rs @@ -684,39 +684,46 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> { hir::FnRetTy::Return(hir_ty) => (fn_decl.output.span(), Some(hir_ty)), }; let mir_description = match hir.body(body).coroutine_kind { - Some(hir::CoroutineKind::Async(src)) => match src { - hir::CoroutineSource::Block => " of async block", - hir::CoroutineSource::Closure => " of async closure", - hir::CoroutineSource::Fn => { - let parent_item = - tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id); - let output = &parent_item - .fn_decl() - .expect("coroutine lowered from async fn should be in fn") - .output; - span = output.span(); - if let hir::FnRetTy::Return(ret) = output { - hir_ty = Some(self.get_future_inner_return_ty(*ret)); + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, src)) => { + match src { + hir::CoroutineSource::Block => " of async block", + hir::CoroutineSource::Closure => " of async closure", + hir::CoroutineSource::Fn => { + let parent_item = + tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id); + let output = &parent_item + .fn_decl() + .expect("coroutine lowered from async fn should be in fn") + .output; + span = output.span(); + if let hir::FnRetTy::Return(ret) = output { + hir_ty = Some(self.get_future_inner_return_ty(*ret)); + } + " of async function" } - " of async function" } - }, - Some(hir::CoroutineKind::Gen(src)) => match src { - hir::CoroutineSource::Block => " of gen block", - hir::CoroutineSource::Closure => " of gen closure", - hir::CoroutineSource::Fn => { - let parent_item = - tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id); - let output = &parent_item - .fn_decl() - .expect("coroutine lowered from gen fn should be in fn") - .output; - span = output.span(); - " of gen function" + } + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, src)) => { + match src { + hir::CoroutineSource::Block => " of gen block", + hir::CoroutineSource::Closure => " of gen closure", + hir::CoroutineSource::Fn => { + let parent_item = + tcx.hir_node_by_def_id(hir.get_parent_item(mir_hir_id).def_id); + let output = &parent_item + .fn_decl() + .expect("coroutine lowered from gen fn should be in fn") + .output; + span = output.span(); + " of gen function" + } } - }, + } - Some(hir::CoroutineKind::AsyncGen(src)) => match src { + Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::AsyncGen, + src, + )) => match src { hir::CoroutineSource::Block => " of async gen block", hir::CoroutineSource::Closure => " of async gen closure", hir::CoroutineSource::Fn => { diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs index dda30046bfbad..2ecc5ad4fe42c 100644 --- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs +++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs @@ -15,7 +15,7 @@ use rustc_data_structures::fx::FxHashSet; use rustc_data_structures::stable_hasher::{Hash64, HashStable, StableHasher}; use rustc_hir::def_id::DefId; use rustc_hir::definitions::{DefPathData, DefPathDataName, DisambiguatedDefPathData}; -use rustc_hir::{CoroutineKind, CoroutineSource, Mutability}; +use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Mutability}; use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; use rustc_middle::ty::{self, ExistentialProjection, ParamEnv, Ty, TyCtxt}; use rustc_middle::ty::{GenericArgKind, GenericArgsRef}; @@ -560,15 +560,31 @@ pub fn push_item_name(tcx: TyCtxt<'_>, def_id: DefId, qualified: bool, output: & fn coroutine_kind_label(coroutine_kind: Option) -> &'static str { match coroutine_kind { - Some(CoroutineKind::Gen(CoroutineSource::Block)) => "gen_block", - Some(CoroutineKind::Gen(CoroutineSource::Closure)) => "gen_closure", - Some(CoroutineKind::Gen(CoroutineSource::Fn)) => "gen_fn", - Some(CoroutineKind::Async(CoroutineSource::Block)) => "async_block", - Some(CoroutineKind::Async(CoroutineSource::Closure)) => "async_closure", - Some(CoroutineKind::Async(CoroutineSource::Fn)) => "async_fn", - Some(CoroutineKind::AsyncGen(CoroutineSource::Block)) => "async_gen_block", - Some(CoroutineKind::AsyncGen(CoroutineSource::Closure)) => "async_gen_closure", - Some(CoroutineKind::AsyncGen(CoroutineSource::Fn)) => "async_gen_fn", + Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Block)) => { + "gen_block" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Closure)) => { + "gen_closure" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Fn)) => "gen_fn", + Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block)) => { + "async_block" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Closure)) => { + "async_closure" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn)) => { + "async_fn" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Block)) => { + "async_gen_block" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Closure)) => { + "async_gen_closure" + } + Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, CoroutineSource::Fn)) => { + "async_gen_fn" + } Some(CoroutineKind::Coroutine) => "coroutine", None => "closure", } diff --git a/compiler/rustc_const_eval/src/transform/check_consts/check.rs b/compiler/rustc_const_eval/src/transform/check_consts/check.rs index 949606ed6c9f9..8a1e5356a159d 100644 --- a/compiler/rustc_const_eval/src/transform/check_consts/check.rs +++ b/compiler/rustc_const_eval/src/transform/check_consts/check.rs @@ -464,8 +464,12 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> { Rvalue::Aggregate(kind, ..) => { if let AggregateKind::Coroutine(def_id, ..) = kind.as_ref() - && let Some(coroutine_kind @ hir::CoroutineKind::Async(..)) = - self.tcx.coroutine_kind(def_id) + && let Some( + coroutine_kind @ hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + _, + ), + ) = self.tcx.coroutine_kind(def_id) { self.check_op(ops::Coroutine(coroutine_kind)); } diff --git a/compiler/rustc_const_eval/src/transform/check_consts/ops.rs b/compiler/rustc_const_eval/src/transform/check_consts/ops.rs index 2de6362b9fe01..23d8a563d910a 100644 --- a/compiler/rustc_const_eval/src/transform/check_consts/ops.rs +++ b/compiler/rustc_const_eval/src/transform/check_consts/ops.rs @@ -359,7 +359,11 @@ impl<'tcx> NonConstOp<'tcx> for FnCallUnstable { pub struct Coroutine(pub hir::CoroutineKind); impl<'tcx> NonConstOp<'tcx> for Coroutine { fn status_in_item(&self, _: &ConstCx<'_, 'tcx>) -> Status { - if let hir::CoroutineKind::Async(hir::CoroutineSource::Block) = self.0 { + if let hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Block, + ) = self.0 + { Status::Unstable(sym::const_async_blocks) } else { Status::Forbidden @@ -372,7 +376,11 @@ impl<'tcx> NonConstOp<'tcx> for Coroutine { span: Span, ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed> { let msg = format!("{:#}s are not allowed in {}s", self.0, ccx.const_kind()); - if let hir::CoroutineKind::Async(hir::CoroutineSource::Block) = self.0 { + if let hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Block, + ) = self.0 + { ccx.tcx.sess.create_feature_err( errors::UnallowedOpInConstContext { span, msg }, sym::const_async_blocks, diff --git a/compiler/rustc_hir/src/hir.rs b/compiler/rustc_hir/src/hir.rs index 26430dcf965fe..452f5d0b7ace9 100644 --- a/compiler/rustc_hir/src/hir.rs +++ b/compiler/rustc_hir/src/hir.rs @@ -1351,15 +1351,8 @@ impl<'hir> Body<'hir> { /// The type of source expression that caused this coroutine to be created. #[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)] pub enum CoroutineKind { - /// An explicit `async` block or the body of an `async` function. - Async(CoroutineSource), - - /// An explicit `gen` block or the body of a `gen` function. - Gen(CoroutineSource), - - /// An explicit `async gen` block or the body of an `async gen` function, - /// which is able to both `yield` and `.await`. - AsyncGen(CoroutineSource), + /// A coroutine that comes from a desugaring. + Desugared(CoroutineDesugaring, CoroutineSource), /// A coroutine literal created via a `yield` inside a closure. Coroutine, @@ -1368,31 +1361,11 @@ pub enum CoroutineKind { impl fmt::Display for CoroutineKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - CoroutineKind::Async(k) => { - if f.alternate() { - f.write_str("`async` ")?; - } else { - f.write_str("async ")? - } + CoroutineKind::Desugared(d, k) => { + d.fmt(f)?; k.fmt(f) } CoroutineKind::Coroutine => f.write_str("coroutine"), - CoroutineKind::Gen(k) => { - if f.alternate() { - f.write_str("`gen` ")?; - } else { - f.write_str("gen ")? - } - k.fmt(f) - } - CoroutineKind::AsyncGen(k) => { - if f.alternate() { - f.write_str("`async gen` ")?; - } else { - f.write_str("async gen ")? - } - k.fmt(f) - } } } } @@ -1425,6 +1398,49 @@ impl fmt::Display for CoroutineSource { } } +#[derive(Clone, PartialEq, Eq, Debug, Copy, Hash, HashStable_Generic, Encodable, Decodable)] +pub enum CoroutineDesugaring { + /// An explicit `async` block or the body of an `async` function. + Async, + + /// An explicit `gen` block or the body of a `gen` function. + Gen, + + /// An explicit `async gen` block or the body of an `async gen` function, + /// which is able to both `yield` and `.await`. + AsyncGen, +} + +impl fmt::Display for CoroutineDesugaring { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CoroutineDesugaring::Async => { + if f.alternate() { + f.write_str("`async` ")?; + } else { + f.write_str("async ")? + } + } + CoroutineDesugaring::Gen => { + if f.alternate() { + f.write_str("`gen` ")?; + } else { + f.write_str("gen ")? + } + } + CoroutineDesugaring::AsyncGen => { + if f.alternate() { + f.write_str("`async gen` ")?; + } else { + f.write_str("async gen ")? + } + } + } + + Ok(()) + } +} + #[derive(Copy, Clone, Debug)] pub enum BodyOwnerKind { /// Functions and methods. diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index 5e6b54950b302..2146effd84f06 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -305,8 +305,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ) = (parent_node, callee_node) { let fn_decl_span = if hir.body(body).coroutine_kind - == Some(hir::CoroutineKind::Async(hir::CoroutineSource::Closure)) - { + == Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Closure, + )) { // Actually need to unwrap one more layer of HIR to get to // the _real_ closure... let async_closure = hir.parent_id(parent_hir_id); diff --git a/compiler/rustc_hir_typeck/src/check.rs b/compiler/rustc_hir_typeck/src/check.rs index 2855cea80b212..8e2af40291879 100644 --- a/compiler/rustc_hir_typeck/src/check.rs +++ b/compiler/rustc_hir_typeck/src/check.rs @@ -59,7 +59,8 @@ pub(super) fn check_fn<'a, 'tcx>( && can_be_coroutine.is_some() { let yield_ty = match kind { - hir::CoroutineKind::Gen(..) | hir::CoroutineKind::Coroutine => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) + | hir::CoroutineKind::Coroutine => { let yield_ty = fcx.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span, @@ -71,7 +72,7 @@ pub(super) fn check_fn<'a, 'tcx>( // guide inference on the yield type so that we can handle `AsyncIterator` // in this block in projection correctly. In the new trait solver, it is // not a problem. - hir::CoroutineKind::AsyncGen(..) => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => { let yield_ty = fcx.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span, @@ -89,7 +90,7 @@ pub(super) fn check_fn<'a, 'tcx>( .into()]), ) } - hir::CoroutineKind::Async(..) => Ty::new_unit(tcx), + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => Ty::new_unit(tcx), }; // Resume type defaults to `()` if the coroutine has no argument. diff --git a/compiler/rustc_hir_typeck/src/closure.rs b/compiler/rustc_hir_typeck/src/closure.rs index d19d304128a1a..cd42be28e6f04 100644 --- a/compiler/rustc_hir_typeck/src/closure.rs +++ b/compiler/rustc_hir_typeck/src/closure.rs @@ -634,7 +634,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // In the case of the async block that we create for a function body, // we expect the return type of the block to match that of the enclosing // function. - Some(hir::CoroutineKind::Async(hir::CoroutineSource::Fn)) => { + Some(hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Fn, + )) => { debug!("closure is async fn body"); let def_id = self.tcx.hir().body_owner_def_id(body.id()); self.deduce_future_output_from_obligations(expr_def_id, def_id).unwrap_or_else( @@ -651,9 +654,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ) } // All `gen {}` and `async gen {}` must return unit. - Some(hir::CoroutineKind::Gen(_) | hir::CoroutineKind::AsyncGen(_)) => { - self.tcx.types.unit - } + Some( + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) + | hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _), + ) => self.tcx.types.unit, _ => astconv.ty_infer(None, decl.output.span()), }, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs index 668e547571f69..d2917b25c54bd 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/suggestions.rs @@ -17,8 +17,8 @@ use rustc_hir::def::Res; use rustc_hir::def::{CtorKind, CtorOf, DefKind}; use rustc_hir::lang_items::LangItem; use rustc_hir::{ - CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node, Path, QPath, Stmt, - StmtKind, TyKind, WherePredicate, + CoroutineDesugaring, CoroutineKind, CoroutineSource, Expr, ExprKind, GenericBound, HirId, Node, + Path, QPath, Stmt, StmtKind, TyKind, WherePredicate, }; use rustc_hir_analysis::astconv::AstConv; use rustc_infer::traits::{self, StatementAsExpression}; @@ -549,7 +549,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ty::Coroutine(def_id, ..) if matches!( self.tcx.coroutine_kind(def_id), - Some(CoroutineKind::Async(CoroutineSource::Closure)) + Some(CoroutineKind::Desugared( + CoroutineDesugaring::Async, + CoroutineSource::Closure + )) ) => { errors::SuggestBoxing::AsyncBody diff --git a/compiler/rustc_metadata/src/rmeta/table.rs b/compiler/rustc_metadata/src/rmeta/table.rs index 0f5d4d9476d63..667fc30199187 100644 --- a/compiler/rustc_metadata/src/rmeta/table.rs +++ b/compiler/rustc_metadata/src/rmeta/table.rs @@ -206,16 +206,16 @@ fixed_size_enum! { fixed_size_enum! { hir::CoroutineKind { - ( Coroutine ) - ( Gen(hir::CoroutineSource::Block) ) - ( Gen(hir::CoroutineSource::Fn) ) - ( Gen(hir::CoroutineSource::Closure) ) - ( Async(hir::CoroutineSource::Block) ) - ( Async(hir::CoroutineSource::Fn) ) - ( Async(hir::CoroutineSource::Closure) ) - ( AsyncGen(hir::CoroutineSource::Block) ) - ( AsyncGen(hir::CoroutineSource::Fn) ) - ( AsyncGen(hir::CoroutineSource::Closure) ) + ( Coroutine ) + ( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Block) ) + ( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Fn) ) + ( Desugared(hir::CoroutineDesugaring::Gen, hir::CoroutineSource::Closure) ) + ( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Block) ) + ( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Fn) ) + ( Desugared(hir::CoroutineDesugaring::Async, hir::CoroutineSource::Closure) ) + ( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Block) ) + ( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Fn) ) + ( Desugared(hir::CoroutineDesugaring::AsyncGen, hir::CoroutineSource::Closure) ) } } diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs index 1e5a7401c6f94..5c425fef27ebc 100644 --- a/compiler/rustc_middle/src/mir/mod.rs +++ b/compiler/rustc_middle/src/mir/mod.rs @@ -17,7 +17,7 @@ use rustc_data_structures::captures::Captures; use rustc_errors::{DiagnosticArgValue, DiagnosticMessage, ErrorGuaranteed, IntoDiagnosticArg}; use rustc_hir::def::{CtorKind, Namespace}; use rustc_hir::def_id::{DefId, CRATE_DEF_ID}; -use rustc_hir::{self, CoroutineKind, ImplicitSelfKind}; +use rustc_hir::{self, CoroutineDesugaring, CoroutineKind, ImplicitSelfKind}; use rustc_hir::{self as hir, HirId}; use rustc_session::Session; use rustc_target::abi::{FieldIdx, VariantIdx}; diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 98e3a1f604e6e..e0c9def037948 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -148,19 +148,23 @@ impl AssertKind { DivisionByZero(_) => "attempt to divide by zero", RemainderByZero(_) => "attempt to calculate the remainder with a divisor of zero", ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", - ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", - ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => { + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { + "`async fn` resumed after completion" + } + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => { "`async gen fn` resumed after completion" } - ResumedAfterReturn(CoroutineKind::Gen(_)) => { + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => { "`gen fn` should just keep returning `None` after completion" } ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", - ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", - ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => { + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { + "`async fn` resumed after panicking" + } + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => { "`async gen fn` resumed after panicking" } - ResumedAfterPanic(CoroutineKind::Gen(_)) => { + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => { "`gen fn` should just keep returning `None` after panicking" } @@ -249,17 +253,27 @@ impl AssertKind { OverflowNeg(_) => middle_assert_overflow_neg, DivisionByZero(_) => middle_assert_divide_by_zero, RemainderByZero(_) => middle_assert_remainder_by_zero, - ResumedAfterReturn(CoroutineKind::Async(_)) => middle_assert_async_resume_after_return, - ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => todo!(), - ResumedAfterReturn(CoroutineKind::Gen(_)) => { + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { + middle_assert_async_resume_after_return + } + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => { + todo!() + } + ResumedAfterReturn(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => { bug!("gen blocks can be resumed after they return and will keep returning `None`") } ResumedAfterReturn(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_return } - ResumedAfterPanic(CoroutineKind::Async(_)) => middle_assert_async_resume_after_panic, - ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => todo!(), - ResumedAfterPanic(CoroutineKind::Gen(_)) => middle_assert_gen_resume_after_panic, + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { + middle_assert_async_resume_after_panic + } + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) => { + todo!() + } + ResumedAfterPanic(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) => { + middle_assert_gen_resume_after_panic + } ResumedAfterPanic(CoroutineKind::Coroutine) => { middle_assert_coroutine_resume_after_panic } diff --git a/compiler/rustc_middle/src/ty/context.rs b/compiler/rustc_middle/src/ty/context.rs index b5ca700c2cd5a..655dde1d9c987 100644 --- a/compiler/rustc_middle/src/ty/context.rs +++ b/compiler/rustc_middle/src/ty/context.rs @@ -849,7 +849,10 @@ impl<'tcx> TyCtxt<'tcx> { /// Returns `true` if the node pointed to by `def_id` is a coroutine for an async construct. pub fn coroutine_is_async(self, def_id: DefId) -> bool { - matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Async(_))) + matches!( + self.coroutine_kind(def_id), + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) + ) } /// Returns `true` if the node pointed to by `def_id` is a general coroutine that implements `Coroutine`. @@ -860,12 +863,18 @@ impl<'tcx> TyCtxt<'tcx> { /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `gen` construct. pub fn coroutine_is_gen(self, def_id: DefId) -> bool { - matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::Gen(_))) + matches!( + self.coroutine_kind(def_id), + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) + ) } /// Returns `true` if the node pointed to by `def_id` is a coroutine for a `async gen` construct. pub fn coroutine_is_async_gen(self, def_id: DefId) -> bool { - matches!(self.coroutine_kind(def_id), Some(hir::CoroutineKind::AsyncGen(_))) + matches!( + self.coroutine_kind(def_id), + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) + ) } pub fn stability(self) -> &'tcx stability::Index { diff --git a/compiler/rustc_middle/src/ty/util.rs b/compiler/rustc_middle/src/ty/util.rs index 8b2b76764e646..ab2488f2e9103 100644 --- a/compiler/rustc_middle/src/ty/util.rs +++ b/compiler/rustc_middle/src/ty/util.rs @@ -728,10 +728,16 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "method", DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { - rustc_hir::CoroutineKind::Async(..) => "async closure", - rustc_hir::CoroutineKind::AsyncGen(..) => "async gen closure", - rustc_hir::CoroutineKind::Coroutine => "coroutine", - rustc_hir::CoroutineKind::Gen(..) => "gen closure", + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => { + "async closure" + } + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => { + "async gen closure" + } + hir::CoroutineKind::Coroutine => "coroutine", + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => { + "gen closure" + } } } _ => def_kind.descr(def_id), @@ -749,10 +755,10 @@ impl<'tcx> TyCtxt<'tcx> { DefKind::AssocFn if self.associated_item(def_id).fn_has_self_parameter => "a", DefKind::Closure if let Some(coroutine_kind) = self.coroutine_kind(def_id) => { match coroutine_kind { - rustc_hir::CoroutineKind::Async(..) => "an", - rustc_hir::CoroutineKind::AsyncGen(..) => "an", - rustc_hir::CoroutineKind::Coroutine => "a", - rustc_hir::CoroutineKind::Gen(..) => "a", + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, ..) => "an", + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, ..) => "an", + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, ..) => "a", + hir::CoroutineKind::Coroutine => "a", } } _ => def_kind.article(), diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index d7dd44af7d251..6da102dcb1c61 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -59,7 +59,7 @@ use rustc_data_structures::fx::{FxHashMap, FxHashSet}; use rustc_errors::pluralize; use rustc_hir as hir; use rustc_hir::lang_items::LangItem; -use rustc_hir::CoroutineKind; +use rustc_hir::{CoroutineDesugaring, CoroutineKind}; use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet}; use rustc_index::{Idx, IndexVec}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; @@ -254,10 +254,12 @@ impl<'tcx> TransformVisitor<'tcx> { let source_info = SourceInfo::outermost(body.span); let none_value = match self.coroutine_kind { - CoroutineKind::Async(_) => span_bug!(body.span, "`Future`s are not fused inherently"), + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { + span_bug!(body.span, "`Future`s are not fused inherently") + } CoroutineKind::Coroutine => span_bug!(body.span, "`Coroutine`s cannot be fused"), // `gen` continues return `None` - CoroutineKind::Gen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); Rvalue::Aggregate( Box::new(AggregateKind::Adt( @@ -271,7 +273,7 @@ impl<'tcx> TransformVisitor<'tcx> { ) } // `async gen` continues to return `Poll::Ready(None)` - CoroutineKind::AsyncGen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; let yield_ty = args.type_at(0); @@ -316,7 +318,7 @@ impl<'tcx> TransformVisitor<'tcx> { statements: &mut Vec>, ) { let rvalue = match self.coroutine_kind { - CoroutineKind::Async(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None); let args = self.tcx.mk_args(&[self.old_ret_ty.into()]); if is_return { @@ -345,7 +347,7 @@ impl<'tcx> TransformVisitor<'tcx> { ) } } - CoroutineKind::Gen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); let args = self.tcx.mk_args(&[self.old_yield_ty.into()]); if is_return { @@ -374,7 +376,7 @@ impl<'tcx> TransformVisitor<'tcx> { ) } } - CoroutineKind::AsyncGen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { if is_return { let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() }; let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() }; @@ -1426,10 +1428,11 @@ fn create_coroutine_resume_function<'tcx>( if can_return { let block = match coroutine_kind { - CoroutineKind::Async(_) | CoroutineKind::Coroutine => { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) | CoroutineKind::Coroutine => { insert_panic_block(tcx, body, ResumedAfterReturn(coroutine_kind)) } - CoroutineKind::AsyncGen(_) | CoroutineKind::Gen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) + | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { transform.insert_none_ret_block(body) } }; @@ -1443,7 +1446,7 @@ fn create_coroutine_resume_function<'tcx>( match coroutine_kind { // Iterator::next doesn't accept a pinned argument, // unlike for all other coroutine kinds. - CoroutineKind::Gen(_) => {} + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {} _ => { make_coroutine_state_argument_pinned(tcx, body); } @@ -1609,25 +1612,34 @@ impl<'tcx> MirPass<'tcx> for StateTransform { } }; - let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); - let is_async_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::AsyncGen(_))); - let is_gen_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Gen(_))); + let is_async_kind = matches!( + body.coroutine_kind(), + Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) + ); + let is_async_gen_kind = matches!( + body.coroutine_kind(), + Some(CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)) + ); + let is_gen_kind = matches!( + body.coroutine_kind(), + Some(CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) + ); let new_ret_ty = match body.coroutine_kind().unwrap() { - CoroutineKind::Async(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { // Compute Poll let poll_did = tcx.require_lang_item(LangItem::Poll, None); let poll_adt_ref = tcx.adt_def(poll_did); let poll_args = tcx.mk_args(&[old_ret_ty.into()]); Ty::new_adt(tcx, poll_adt_ref, poll_args) } - CoroutineKind::Gen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { // Compute Option let option_did = tcx.require_lang_item(LangItem::Option, None); let option_adt_ref = tcx.adt_def(option_did); let option_args = tcx.mk_args(&[old_yield_ty.into()]); Ty::new_adt(tcx, option_adt_ref, option_args) } - CoroutineKind::AsyncGen(_) => { + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { // The yield ty is already `Poll>` old_yield_ty } diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs index 8b7b26f969c96..5f505ac181cad 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mod.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mod.rs @@ -41,17 +41,26 @@ impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineSource { impl<'tcx> Stable<'tcx> for rustc_hir::CoroutineKind { type T = stable_mir::mir::CoroutineKind; fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T { - use rustc_hir::CoroutineKind; + use rustc_hir::{CoroutineDesugaring, CoroutineKind}; match self { - CoroutineKind::Async(source) => { - stable_mir::mir::CoroutineKind::Async(source.stable(tables)) + CoroutineKind::Desugared(CoroutineDesugaring::Async, source) => { + stable_mir::mir::CoroutineKind::Desugared( + stable_mir::mir::CoroutineDesugaring::Async, + source.stable(tables), + ) } - CoroutineKind::Gen(source) => { - stable_mir::mir::CoroutineKind::Gen(source.stable(tables)) + CoroutineKind::Desugared(CoroutineDesugaring::Gen, source) => { + stable_mir::mir::CoroutineKind::Desugared( + stable_mir::mir::CoroutineDesugaring::Gen, + source.stable(tables), + ) } CoroutineKind::Coroutine => stable_mir::mir::CoroutineKind::Coroutine, - CoroutineKind::AsyncGen(source) => { - stable_mir::mir::CoroutineKind::AsyncGen(source.stable(tables)) + CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, source) => { + stable_mir::mir::CoroutineKind::Desugared( + stable_mir::mir::CoroutineDesugaring::AsyncGen, + source.stable(tables), + ) } } } diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs index a1b896d225125..066510b45e92d 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/suggestions.rs @@ -22,7 +22,7 @@ use rustc_hir::def_id::DefId; use rustc_hir::intravisit::Visitor; use rustc_hir::is_range_literal; use rustc_hir::lang_items::LangItem; -use rustc_hir::{CoroutineKind, CoroutineSource, Node}; +use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Node}; use rustc_hir::{Expr, HirId}; use rustc_infer::infer::error_reporting::TypeErrCtxt; use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind}; @@ -2578,7 +2578,10 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { .and_then(|coroutine_did| { Some(match self.tcx.coroutine_kind(coroutine_did).unwrap() { CoroutineKind::Coroutine => format!("coroutine is not {trait_name}"), - CoroutineKind::Async(CoroutineSource::Fn) => self + CoroutineKind::Desugared( + CoroutineDesugaring::Async, + CoroutineSource::Fn, + ) => self .tcx .parent(coroutine_did) .as_local() @@ -2587,13 +2590,22 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { .map(|name| { format!("future returned by `{name}` is not {trait_name}") })?, - CoroutineKind::Async(CoroutineSource::Block) => { + CoroutineKind::Desugared( + CoroutineDesugaring::Async, + CoroutineSource::Block, + ) => { format!("future created by async block is not {trait_name}") } - CoroutineKind::Async(CoroutineSource::Closure) => { + CoroutineKind::Desugared( + CoroutineDesugaring::Async, + CoroutineSource::Closure, + ) => { format!("future created by async closure is not {trait_name}") } - CoroutineKind::AsyncGen(CoroutineSource::Fn) => self + CoroutineKind::Desugared( + CoroutineDesugaring::AsyncGen, + CoroutineSource::Fn, + ) => self .tcx .parent(coroutine_did) .as_local() @@ -2602,27 +2614,40 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { .map(|name| { format!("async iterator returned by `{name}` is not {trait_name}") })?, - CoroutineKind::AsyncGen(CoroutineSource::Block) => { + CoroutineKind::Desugared( + CoroutineDesugaring::AsyncGen, + CoroutineSource::Block, + ) => { format!("async iterator created by async gen block is not {trait_name}") } - CoroutineKind::AsyncGen(CoroutineSource::Closure) => { + CoroutineKind::Desugared( + CoroutineDesugaring::AsyncGen, + CoroutineSource::Closure, + ) => { format!( "async iterator created by async gen closure is not {trait_name}" ) } - CoroutineKind::Gen(CoroutineSource::Fn) => self - .tcx - .parent(coroutine_did) - .as_local() - .map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did)) - .and_then(|parent_hir_id| hir.opt_name(parent_hir_id)) - .map(|name| { - format!("iterator returned by `{name}` is not {trait_name}") - })?, - CoroutineKind::Gen(CoroutineSource::Block) => { + CoroutineKind::Desugared(CoroutineDesugaring::Gen, CoroutineSource::Fn) => { + self.tcx + .parent(coroutine_did) + .as_local() + .map(|parent_did| self.tcx.local_def_id_to_hir_id(parent_did)) + .and_then(|parent_hir_id| hir.opt_name(parent_hir_id)) + .map(|name| { + format!("iterator returned by `{name}` is not {trait_name}") + })? + } + CoroutineKind::Desugared( + CoroutineDesugaring::Gen, + CoroutineSource::Block, + ) => { format!("iterator created by gen block is not {trait_name}") } - CoroutineKind::Gen(CoroutineSource::Closure) => { + CoroutineKind::Desugared( + CoroutineDesugaring::Gen, + CoroutineSource::Closure, + ) => { format!("iterator created by gen closure is not {trait_name}") } }) @@ -3145,9 +3170,15 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { let what = match self.tcx.coroutine_kind(coroutine_def_id) { None | Some(hir::CoroutineKind::Coroutine) - | Some(hir::CoroutineKind::Gen(_)) => "yield", - Some(hir::CoroutineKind::Async(..)) => "await", - Some(hir::CoroutineKind::AsyncGen(_)) => "yield`/`await", + | Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _)) => { + "yield" + } + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) => { + "await" + } + Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _)) => { + "yield`/`await" + } }; err.note(format!( "all values live across `{what}` must have a statically known size" @@ -3535,7 +3566,9 @@ impl<'tcx> TypeErrCtxtExt<'tcx> for TypeErrCtxt<'_, 'tcx> { ) { if let Some(body_id) = self.tcx.hir().maybe_body_owned_by(obligation.cause.body_id) { let body = self.tcx.hir().body(body_id); - if let Some(hir::CoroutineKind::Async(_)) = body.coroutine_kind { + if let Some(hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _)) = + body.coroutine_kind + { let future_trait = self.tcx.require_lang_item(LangItem::Future, None); let self_ty = self.resolve_vars_if_possible(trait_pred.self_ty()); diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs index 9ee091bbd1e6d..998cb7539808a 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/type_err_ctxt_ext.rs @@ -1,3 +1,5 @@ +// ignore-tidy-filelength :( + use super::on_unimplemented::{AppendConstMessage, OnUnimplementedNote, TypeErrCtxtExt as _}; use super::suggestions::{get_explanation_based_on_obligation, TypeErrCtxtExt as _}; use crate::errors::{ClosureFnMutLabel, ClosureFnOnceLabel, ClosureKindMismatch}; @@ -1926,15 +1928,42 @@ impl<'tcx> InferCtxtPrivExt<'tcx> for TypeErrCtxt<'_, 'tcx> { fn describe_coroutine(&self, body_id: hir::BodyId) -> Option<&'static str> { self.tcx.hir().body(body_id).coroutine_kind.map(|coroutine_source| match coroutine_source { hir::CoroutineKind::Coroutine => "a coroutine", - hir::CoroutineKind::Async(hir::CoroutineSource::Block) => "an async block", - hir::CoroutineKind::Async(hir::CoroutineSource::Fn) => "an async function", - hir::CoroutineKind::Async(hir::CoroutineSource::Closure) => "an async closure", - hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Block) => "an async gen block", - hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Fn) => "an async gen function", - hir::CoroutineKind::AsyncGen(hir::CoroutineSource::Closure) => "an async gen closure", - hir::CoroutineKind::Gen(hir::CoroutineSource::Block) => "a gen block", - hir::CoroutineKind::Gen(hir::CoroutineSource::Fn) => "a gen function", - hir::CoroutineKind::Gen(hir::CoroutineSource::Closure) => "a gen closure", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Block, + ) => "an async block", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Fn, + ) => "an async function", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Async, + hir::CoroutineSource::Closure, + ) => "an async closure", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::AsyncGen, + hir::CoroutineSource::Block, + ) => "an async gen block", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::AsyncGen, + hir::CoroutineSource::Fn, + ) => "an async gen function", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::AsyncGen, + hir::CoroutineSource::Closure, + ) => "an async gen closure", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Gen, + hir::CoroutineSource::Block, + ) => "a gen block", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Gen, + hir::CoroutineSource::Fn, + ) => "a gen function", + hir::CoroutineKind::Desugared( + hir::CoroutineDesugaring::Gen, + hir::CoroutineSource::Closure, + ) => "a gen closure", }) } diff --git a/compiler/rustc_ty_utils/src/abi.rs b/compiler/rustc_ty_utils/src/abi.rs index a5f11ca23e124..86501b5a72d15 100644 --- a/compiler/rustc_ty_utils/src/abi.rs +++ b/compiler/rustc_ty_utils/src/abi.rs @@ -114,13 +114,13 @@ fn fn_sig_for_fn_abi<'tcx>( let pin_adt_ref = tcx.adt_def(pin_did); let pin_args = tcx.mk_args(&[env_ty.into()]); let env_ty = match coroutine_kind { - hir::CoroutineKind::Gen(_) => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => { // Iterator::next doesn't accept a pinned argument, // unlike for all other coroutine kinds. env_ty } - hir::CoroutineKind::Async(_) - | hir::CoroutineKind::AsyncGen(_) + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) + | hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) | hir::CoroutineKind::Coroutine => Ty::new_adt(tcx, pin_adt_ref, pin_args), }; @@ -131,7 +131,7 @@ fn fn_sig_for_fn_abi<'tcx>( // or the `Iterator::next(...) -> Option` function in case this is a // special coroutine backing a gen construct. let (resume_ty, ret_ty) = match coroutine_kind { - hir::CoroutineKind::Async(_) => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, _) => { // The signature should be `Future::poll(_, &mut Context<'_>) -> Poll` assert_eq!(sig.yield_ty, tcx.types.unit); @@ -156,7 +156,7 @@ fn fn_sig_for_fn_abi<'tcx>( (Some(context_mut_ref), ret_ty) } - hir::CoroutineKind::Gen(_) => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, _) => { // The signature should be `Iterator::next(_) -> Option` let option_did = tcx.require_lang_item(LangItem::Option, None); let option_adt_ref = tcx.adt_def(option_did); @@ -168,7 +168,7 @@ fn fn_sig_for_fn_abi<'tcx>( (None, ret_ty) } - hir::CoroutineKind::AsyncGen(_) => { + hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::AsyncGen, _) => { // The signature should be // `AsyncIterator::poll_next(_, &mut Context<'_>) -> Poll>` assert_eq!(sig.return_ty, tcx.types.unit); diff --git a/compiler/stable_mir/src/mir/body.rs b/compiler/stable_mir/src/mir/body.rs index b8fd9370aa618..89d75569ce3db 100644 --- a/compiler/stable_mir/src/mir/body.rs +++ b/compiler/stable_mir/src/mir/body.rs @@ -288,27 +288,33 @@ impl AssertMessage { AssertMessage::ResumedAfterReturn(CoroutineKind::Coroutine) => { Ok("coroutine resumed after completion") } - AssertMessage::ResumedAfterReturn(CoroutineKind::Async(_)) => { - Ok("`async fn` resumed after completion") - } - AssertMessage::ResumedAfterReturn(CoroutineKind::Gen(_)) => { - Ok("`async gen fn` resumed after completion") - } - AssertMessage::ResumedAfterReturn(CoroutineKind::AsyncGen(_)) => { - Ok("`gen fn` should just keep returning `AssertMessage::None` after completion") - } + AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared( + CoroutineDesugaring::Async, + _, + )) => Ok("`async fn` resumed after completion"), + AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared( + CoroutineDesugaring::Gen, + _, + )) => Ok("`async gen fn` resumed after completion"), + AssertMessage::ResumedAfterReturn(CoroutineKind::Desugared( + CoroutineDesugaring::AsyncGen, + _, + )) => Ok("`gen fn` should just keep returning `AssertMessage::None` after completion"), AssertMessage::ResumedAfterPanic(CoroutineKind::Coroutine) => { Ok("coroutine resumed after panicking") } - AssertMessage::ResumedAfterPanic(CoroutineKind::Async(_)) => { - Ok("`async fn` resumed after panicking") - } - AssertMessage::ResumedAfterPanic(CoroutineKind::Gen(_)) => { - Ok("`async gen fn` resumed after panicking") - } - AssertMessage::ResumedAfterPanic(CoroutineKind::AsyncGen(_)) => { - Ok("`gen fn` should just keep returning `AssertMessage::None` after panicking") - } + AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared( + CoroutineDesugaring::Async, + _, + )) => Ok("`async fn` resumed after panicking"), + AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared( + CoroutineDesugaring::Gen, + _, + )) => Ok("`async gen fn` resumed after panicking"), + AssertMessage::ResumedAfterPanic(CoroutineKind::Desugared( + CoroutineDesugaring::AsyncGen, + _, + )) => Ok("`gen fn` should just keep returning `AssertMessage::None` after panicking"), AssertMessage::BoundsCheck { .. } => Ok("index out of bounds"), AssertMessage::MisalignedPointerDereference { .. } => { @@ -392,10 +398,8 @@ pub enum UnOp { #[derive(Clone, Debug, Eq, PartialEq)] pub enum CoroutineKind { - Async(CoroutineSource), + Desugared(CoroutineDesugaring, CoroutineSource), Coroutine, - Gen(CoroutineSource), - AsyncGen(CoroutineSource), } #[derive(Copy, Clone, Debug, Eq, PartialEq)] @@ -405,6 +409,15 @@ pub enum CoroutineSource { Fn, } +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum CoroutineDesugaring { + Async, + + Gen, + + AsyncGen, +} + pub(crate) type LocalDefId = Opaque; /// The rustc coverage data structures are heavily tied to internal details of the /// coverage implementation that are likely to change, and are unlikely to be diff --git a/src/tools/clippy/clippy_lints/src/async_yields_async.rs b/src/tools/clippy/clippy_lints/src/async_yields_async.rs index 3e5a01c45df1d..28e6614f03fb6 100644 --- a/src/tools/clippy/clippy_lints/src/async_yields_async.rs +++ b/src/tools/clippy/clippy_lints/src/async_yields_async.rs @@ -2,7 +2,7 @@ use clippy_utils::diagnostics::span_lint_hir_and_then; use clippy_utils::source::snippet; use clippy_utils::ty::implements_trait; use rustc_errors::Applicability; -use rustc_hir::{Body, BodyId, CoroutineKind, CoroutineSource, ExprKind, QPath}; +use rustc_hir::{Body, BodyId, CoroutineKind, CoroutineSource, CoroutineDesugaring, ExprKind, QPath}; use rustc_lint::{LateContext, LateLintPass}; use rustc_session::declare_lint_pass; @@ -45,10 +45,9 @@ declare_lint_pass!(AsyncYieldsAsync => [ASYNC_YIELDS_ASYNC]); impl<'tcx> LateLintPass<'tcx> for AsyncYieldsAsync { fn check_body(&mut self, cx: &LateContext<'tcx>, body: &'tcx Body<'_>) { - use CoroutineSource::{Block, Closure}; // For functions, with explicitly defined types, don't warn. // XXXkhuey maybe we should? - if let Some(CoroutineKind::Async(Block | Closure)) = body.coroutine_kind { + if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block | CoroutineSource::Closure)) = body.coroutine_kind { if let Some(future_trait_def_id) = cx.tcx.lang_items().future_trait() { let body_id = BodyId { hir_id: body.value.hir_id, diff --git a/src/tools/clippy/clippy_lints/src/await_holding_invalid.rs b/src/tools/clippy/clippy_lints/src/await_holding_invalid.rs index 9894a1639618d..dff6e884fa118 100644 --- a/src/tools/clippy/clippy_lints/src/await_holding_invalid.rs +++ b/src/tools/clippy/clippy_lints/src/await_holding_invalid.rs @@ -3,7 +3,7 @@ use clippy_utils::diagnostics::span_lint_and_then; use clippy_utils::{match_def_path, paths}; use rustc_data_structures::fx::FxHashMap; use rustc_hir::def_id::DefId; -use rustc_hir::{Body, CoroutineKind, CoroutineSource}; +use rustc_hir::{Body, CoroutineKind, CoroutineDesugaring}; use rustc_lint::{LateContext, LateLintPass}; use rustc_middle::mir::CoroutineLayout; use rustc_session::impl_lint_pass; @@ -194,8 +194,7 @@ impl LateLintPass<'_> for AwaitHolding { } fn check_body(&mut self, cx: &LateContext<'_>, body: &'_ Body<'_>) { - use CoroutineSource::{Block, Closure, Fn}; - if let Some(CoroutineKind::Async(Block | Closure | Fn)) = body.coroutine_kind { + if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) = body.coroutine_kind { let def_id = cx.tcx.hir().body_owner_def_id(body.id()); if let Some(coroutine_layout) = cx.tcx.mir_coroutine_witnesses(def_id) { self.check_interior_types(cx, coroutine_layout); diff --git a/src/tools/clippy/clippy_lints/src/manual_async_fn.rs b/src/tools/clippy/clippy_lints/src/manual_async_fn.rs index eaaaea0be9f89..8982ce5e196ec 100644 --- a/src/tools/clippy/clippy_lints/src/manual_async_fn.rs +++ b/src/tools/clippy/clippy_lints/src/manual_async_fn.rs @@ -3,7 +3,7 @@ use clippy_utils::source::{position_before_rarrow, snippet_block, snippet_opt}; use rustc_errors::Applicability; use rustc_hir::intravisit::FnKind; use rustc_hir::{ - Block, Body, Closure, CoroutineKind, CoroutineSource, Expr, ExprKind, FnDecl, FnRetTy, GenericArg, GenericBound, + Block, Body, Closure, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, FnDecl, FnRetTy, GenericArg, GenericBound, ImplItem, Item, ItemKind, LifetimeName, Node, Term, TraitRef, Ty, TyKind, TypeBindingKind, }; use rustc_lint::{LateContext, LateLintPass}; @@ -178,7 +178,7 @@ fn desugared_async_block<'tcx>(cx: &LateContext<'tcx>, block: &'tcx Block<'tcx>) .. } = block_expr && let closure_body = cx.tcx.hir().body(body) - && closure_body.coroutine_kind == Some(CoroutineKind::Async(CoroutineSource::Block)) + && closure_body.coroutine_kind == Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block)) { return Some(closure_body); } diff --git a/src/tools/clippy/clippy_lints/src/needless_question_mark.rs b/src/tools/clippy/clippy_lints/src/needless_question_mark.rs index a4d3aaf0de988..350707d3a1361 100644 --- a/src/tools/clippy/clippy_lints/src/needless_question_mark.rs +++ b/src/tools/clippy/clippy_lints/src/needless_question_mark.rs @@ -3,7 +3,7 @@ use clippy_utils::path_res; use clippy_utils::source::snippet; use rustc_errors::Applicability; use rustc_hir::def::{DefKind, Res}; -use rustc_hir::{Block, Body, CoroutineKind, CoroutineSource, Expr, ExprKind, LangItem, MatchSource, QPath}; +use rustc_hir::{Block, Body, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, LangItem, MatchSource, QPath}; use rustc_lint::{LateContext, LateLintPass}; use rustc_session::declare_lint_pass; @@ -86,7 +86,7 @@ impl LateLintPass<'_> for NeedlessQuestionMark { } fn check_body(&mut self, cx: &LateContext<'_>, body: &'_ Body<'_>) { - if let Some(CoroutineKind::Async(CoroutineSource::Fn)) = body.coroutine_kind { + if let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn)) = body.coroutine_kind { if let ExprKind::Block( Block { expr: diff --git a/src/tools/clippy/clippy_lints/src/redundant_async_block.rs b/src/tools/clippy/clippy_lints/src/redundant_async_block.rs index 19d9d64b31e3c..4b3fe9c0bb556 100644 --- a/src/tools/clippy/clippy_lints/src/redundant_async_block.rs +++ b/src/tools/clippy/clippy_lints/src/redundant_async_block.rs @@ -5,7 +5,7 @@ use clippy_utils::peel_blocks; use clippy_utils::source::{snippet, walk_span_to_context}; use clippy_utils::visitors::for_each_expr; use rustc_errors::Applicability; -use rustc_hir::{Closure, CoroutineKind, CoroutineSource, Expr, ExprKind, MatchSource}; +use rustc_hir::{Closure, CoroutineKind, CoroutineSource, CoroutineDesugaring, Expr, ExprKind, MatchSource}; use rustc_lint::{LateContext, LateLintPass}; use rustc_middle::lint::in_external_macro; use rustc_middle::ty::UpvarCapture; @@ -71,7 +71,7 @@ impl<'tcx> LateLintPass<'tcx> for RedundantAsyncBlock { fn desugar_async_block<'tcx>(cx: &LateContext<'tcx>, expr: &'tcx Expr<'_>) -> Option<&'tcx Expr<'tcx>> { if let ExprKind::Closure(Closure { body, def_id, .. }) = expr.kind && let body = cx.tcx.hir().body(*body) - && matches!(body.coroutine_kind, Some(CoroutineKind::Async(CoroutineSource::Block))) + && matches!(body.coroutine_kind, Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Block))) { cx.typeck_results() .closure_min_captures diff --git a/src/tools/clippy/clippy_lints/src/redundant_closure_call.rs b/src/tools/clippy/clippy_lints/src/redundant_closure_call.rs index 8bac2e40e0128..9312a9c89b789 100644 --- a/src/tools/clippy/clippy_lints/src/redundant_closure_call.rs +++ b/src/tools/clippy/clippy_lints/src/redundant_closure_call.rs @@ -5,7 +5,7 @@ use clippy_utils::sugg::Sugg; use rustc_errors::Applicability; use rustc_hir as hir; use rustc_hir::intravisit::{Visitor as HirVisitor, Visitor}; -use rustc_hir::{intravisit as hir_visit, CoroutineKind, CoroutineSource, Node}; +use rustc_hir::{intravisit as hir_visit, CoroutineKind, CoroutineSource, CoroutineDesugaring, Node}; use rustc_lint::{LateContext, LateLintPass}; use rustc_middle::hir::nested_filter; use rustc_middle::lint::in_external_macro; @@ -67,7 +67,7 @@ fn is_async_closure(cx: &LateContext<'_>, body: &hir::Body<'_>) -> bool { if let hir::ExprKind::Closure(innermost_closure_generated_by_desugar) = body.value.kind && let desugared_inner_closure_body = cx.tcx.hir().body(innermost_closure_generated_by_desugar.body) // checks whether it is `async || whatever_expression` - && let Some(CoroutineKind::Async(CoroutineSource::Closure)) = desugared_inner_closure_body.coroutine_kind + && let Some(CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Closure)) = desugared_inner_closure_body.coroutine_kind { true } else { diff --git a/src/tools/clippy/clippy_lints/src/unused_async.rs b/src/tools/clippy/clippy_lints/src/unused_async.rs index 9c8c44c0a16df..f71fe4e1e92e1 100644 --- a/src/tools/clippy/clippy_lints/src/unused_async.rs +++ b/src/tools/clippy/clippy_lints/src/unused_async.rs @@ -86,7 +86,13 @@ impl<'a, 'tcx> Visitor<'tcx> for AsyncFnVisitor<'a, 'tcx> { } fn visit_body(&mut self, b: &'tcx Body<'tcx>) { - let is_async_block = matches!(b.coroutine_kind, Some(rustc_hir::CoroutineKind::Async(_))); + let is_async_block = matches!( + b.coroutine_kind, + Some(rustc_hir::CoroutineKind::Desugared( + rustc_hir::CoroutineDesugaring::Async, + _ + )) + ); if is_async_block { self.async_depth += 1;