Skip to content

Commit

Permalink
Auto merge of #118847 - eholk:for-await, r=compiler-errors
Browse files Browse the repository at this point in the history
Add support for `for await` loops

This adds support for `for await` loops. This includes parsing, desugaring in AST->HIR lowering, and adding some support functions to the library.

Given a loop like:
```rust
for await i in iter {
    ...
}
```
this is desugared to something like:
```rust
let mut iter = iter.into_async_iter();
while let Some(i) = loop {
    match core::pin::Pin::new(&mut iter).poll_next(cx) {
        Poll::Ready(i) => break i,
        Poll::Pending => yield,
    }
} {
    ...
}
```

This PR also adds a basic `IntoAsyncIterator` trait. This is partly for symmetry with the way `Iterator` and `IntoIterator` work. The other reason is that for async iterators it's helpful to have a place apart from the data structure being iterated over to store state. `IntoAsyncIterator` gives us a good place to do this.

I've gated this feature behind `async_for_loop` and opened #118898 as the feature tracking issue.

r? `@compiler-errors`
  • Loading branch information
bors committed Dec 22, 2023
2 parents c1fc1d1 + 397f4a1 commit 208dd20
Show file tree
Hide file tree
Showing 34 changed files with 367 additions and 79 deletions.
13 changes: 10 additions & 3 deletions compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ impl Expr {
ExprKind::Let(..) => ExprPrecedence::Let,
ExprKind::If(..) => ExprPrecedence::If,
ExprKind::While(..) => ExprPrecedence::While,
ExprKind::ForLoop(..) => ExprPrecedence::ForLoop,
ExprKind::ForLoop { .. } => ExprPrecedence::ForLoop,
ExprKind::Loop(..) => ExprPrecedence::Loop,
ExprKind::Match(..) => ExprPrecedence::Match,
ExprKind::Closure(..) => ExprPrecedence::Closure,
Expand Down Expand Up @@ -1411,10 +1411,10 @@ pub enum ExprKind {
While(P<Expr>, P<Block>, Option<Label>),
/// A `for` loop, with an optional label.
///
/// `'label: for pat in expr { block }`
/// `'label: for await? pat in iter { block }`
///
/// This is desugared to a combination of `loop` and `match` expressions.
ForLoop(P<Pat>, P<Expr>, P<Block>, Option<Label>),
ForLoop { pat: P<Pat>, iter: P<Expr>, body: P<Block>, label: Option<Label>, kind: ForLoopKind },
/// Conditionless loop (can be exited with `break`, `continue`, or `return`).
///
/// `'label: loop { block }`
Expand Down Expand Up @@ -1517,6 +1517,13 @@ pub enum ExprKind {
Err,
}

/// Used to differentiate between `for` loops and `for await` loops.
#[derive(Clone, Copy, Encodable, Decodable, Debug, PartialEq, Eq)]
pub enum ForLoopKind {
For,
ForAwait,
}

/// Used to differentiate between `async {}` blocks and `gen {}` blocks.
#[derive(Clone, Encodable, Decodable, Debug, PartialEq, Eq)]
pub enum GenBlockKind {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,7 @@ pub fn noop_visit_expr<T: MutVisitor>(
vis.visit_block(body);
visit_opt(label, |label| vis.visit_label(label));
}
ExprKind::ForLoop(pat, iter, body, label) => {
ExprKind::ForLoop { pat, iter, body, label, kind: _ } => {
vis.visit_pat(pat);
vis.visit_expr(iter);
vis.visit_block(body);
Expand Down
14 changes: 11 additions & 3 deletions compiler/rustc_ast/src/util/classify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn expr_requires_semi_to_be_stmt(e: &ast::Expr) -> bool {
| ast::ExprKind::Block(..)
| ast::ExprKind::While(..)
| ast::ExprKind::Loop(..)
| ast::ExprKind::ForLoop(..)
| ast::ExprKind::ForLoop { .. }
| ast::ExprKind::TryBlock(..)
| ast::ExprKind::ConstBlock(..)
)
Expand Down Expand Up @@ -48,8 +48,16 @@ pub fn expr_trailing_brace(mut expr: &ast::Expr) -> Option<&ast::Expr> {
Closure(closure) => {
expr = &closure.body;
}
Gen(..) | Block(..) | ForLoop(..) | If(..) | Loop(..) | Match(..) | Struct(..)
| TryBlock(..) | While(..) | ConstBlock(_) => break Some(expr),
Gen(..)
| Block(..)
| ForLoop { .. }
| If(..)
| Loop(..)
| Match(..)
| Struct(..)
| TryBlock(..)
| While(..)
| ConstBlock(_) => break Some(expr),

// FIXME: These can end in `}`, but changing these would break stable code.
InlineAsm(_) | OffsetOf(_, _) | MacCall(_) | IncludedBytes(_) | FormatArgs(_) => {
Expand Down
10 changes: 5 additions & 5 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,11 +844,11 @@ pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expression: &'a Expr) {
visitor.visit_expr(subexpression);
visitor.visit_block(block);
}
ExprKind::ForLoop(pattern, subexpression, block, opt_label) => {
walk_list!(visitor, visit_label, opt_label);
visitor.visit_pat(pattern);
visitor.visit_expr(subexpression);
visitor.visit_block(block);
ExprKind::ForLoop { pat, iter, body, label, kind: _ } => {
walk_list!(visitor, visit_label, label);
visitor.visit_pat(pat);
visitor.visit_expr(iter);
visitor.visit_block(body);
}
ExprKind::Loop(block, opt_label, _) => {
walk_list!(visitor, visit_label, opt_label);
Expand Down
125 changes: 95 additions & 30 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
return ex;
}
// Desugar `ExprForLoop`
// from: `[opt_ident]: for <pat> in <head> <body>`
// from: `[opt_ident]: for await? <pat> in <iter> <body>`
//
// This also needs special handling because the HirId of the returned `hir::Expr` will not
// correspond to the `e.id`, so `lower_expr_for` handles attribute lowering itself.
ExprKind::ForLoop(pat, head, body, opt_label) => {
return self.lower_expr_for(e, pat, head, body, *opt_label);
ExprKind::ForLoop { pat, iter, body, label, kind } => {
return self.lower_expr_for(e, pat, iter, body, *label, *kind);
}
_ => (),
}
Expand Down Expand Up @@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
),
ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),

ExprKind::Paren(_) | ExprKind::ForLoop(..) => {
ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
unreachable!("already handled")
}

Expand Down Expand Up @@ -874,6 +874,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
/// }
/// ```
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
let expr = self.arena.alloc(self.lower_expr_mut(expr));
self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
}

/// Takes an expr that has already been lowered and generates a desugared await loop around it
fn make_lowered_await(
&mut self,
await_kw_span: Span,
expr: &'hir hir::Expr<'hir>,
await_kind: FutureKind,
) -> hir::ExprKind<'hir> {
let full_span = expr.span.to(await_kw_span);

let is_async_gen = match self.coroutine_kind {
Expand All @@ -887,13 +898,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
};

let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
let features = match await_kind {
FutureKind::Future => None,
FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
};
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
let gen_future_span = self.mark_span_with_reason(
DesugaringKind::Await,
full_span,
Some(self.allow_gen_future.clone()),
);
let expr = self.lower_expr_mut(expr);
let expr_hir_id = expr.hir_id;

// Note that the name of this binding must not be changed to something else because
Expand Down Expand Up @@ -934,11 +948,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::LangItem::GetContext,
arena_vec![self; task_context],
);
let call = self.expr_call_lang_item_fn(
span,
hir::LangItem::FuturePoll,
arena_vec![self; new_unchecked, get_context],
);
let call = match await_kind {
FutureKind::Future => self.expr_call_lang_item_fn(
span,
hir::LangItem::FuturePoll,
arena_vec![self; new_unchecked, get_context],
),
FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
span,
hir::LangItem::AsyncIteratorPollNext,
arena_vec![self; new_unchecked, get_context],
),
};
self.arena.alloc(self.expr_unsafe(call))
};

Expand Down Expand Up @@ -1020,11 +1041,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
let awaitee_arm = self.arm(awaitee_pat, loop_expr);

// `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
let into_future_expr = self.expr_call_lang_item_fn(
span,
hir::LangItem::IntoFutureIntoFuture,
arena_vec![self; expr],
);
let into_future_expr = match await_kind {
FutureKind::Future => self.expr_call_lang_item_fn(
span,
hir::LangItem::IntoFutureIntoFuture,
arena_vec![self; *expr],
),
// Not needed for `for await` because we expect to have already called
// `IntoAsyncIterator::into_async_iter` on it.
FutureKind::AsyncIterator => expr,
};

// match <into_future_expr> {
// mut __awaitee => loop { .. }
Expand Down Expand Up @@ -1685,6 +1711,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
head: &Expr,
body: &Block,
opt_label: Option<Label>,
loop_kind: ForLoopKind,
) -> hir::Expr<'hir> {
let head = self.lower_expr_mut(head);
let pat = self.lower_pat(pat);
Expand Down Expand Up @@ -1713,17 +1740,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
let (iter_pat, iter_pat_nid) =
self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);

// `match Iterator::next(&mut iter) { ... }`
let match_expr = {
let iter = self.expr_ident(head_span, iter, iter_pat_nid);
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
let next_expr = self.expr_call_lang_item_fn(
head_span,
hir::LangItem::IteratorNext,
arena_vec![self; ref_mut_iter],
);
let next_expr = match loop_kind {
ForLoopKind::For => {
// `Iterator::next(&mut iter)`
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
self.expr_call_lang_item_fn(
head_span,
hir::LangItem::IteratorNext,
arena_vec![self; ref_mut_iter],
)
}
ForLoopKind::ForAwait => {
// we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
// to make_lowered_await with `FutureKind::AsyncIterator` which will generator
// calls to `poll_next`. In user code, this would probably be a call to
// `Pin::as_mut` but here it's easy enough to do `new_unchecked`.

// `&mut iter`
let iter = self.expr_mut_addr_of(head_span, iter);
// `Pin::new_unchecked(...)`
let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut(
head_span,
hir::LangItem::PinNewUnchecked,
arena_vec![self; iter],
));
// `unsafe { ... }`
let iter = self.arena.alloc(self.expr_unsafe(iter));
let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
}
};
let arms = arena_vec![self; none_arm, some_arm];

// `match $next_expr { ... }`
self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
};
let match_stmt = self.stmt_expr(for_span, match_expr);
Expand All @@ -1743,13 +1794,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
// `mut iter => { ... }`
let iter_arm = self.arm(iter_pat, loop_expr);

// `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
let into_iter_expr = {
self.expr_call_lang_item_fn(
head_span,
hir::LangItem::IntoIterIntoIter,
arena_vec![self; head],
)
let into_iter_expr = match loop_kind {
ForLoopKind::For => {
// `::std::iter::IntoIterator::into_iter(<head>)`
self.expr_call_lang_item_fn(
head_span,
hir::LangItem::IntoIterIntoIter,
arena_vec![self; head],
)
}
ForLoopKind::ForAwait => self.arena.alloc(head),
};

let match_expr = self.arena.alloc(self.expr_match(
Expand Down Expand Up @@ -2152,3 +2206,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
}
}

/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
/// of future we are awaiting.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum FutureKind {
/// We are awaiting a normal future
Future,
/// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
/// a `for await` loop)
AsyncIterator,
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast_lowering/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ struct LoweringContext<'a, 'hir> {
allow_try_trait: Lrc<[Symbol]>,
allow_gen_future: Lrc<[Symbol]>,
allow_async_iterator: Lrc<[Symbol]>,
allow_for_await: Lrc<[Symbol]>,

/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
Expand Down Expand Up @@ -174,6 +175,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
} else {
[sym::gen_future].into()
},
allow_for_await: [sym::async_iterator].into(),
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
// interact with `gen`/`async gen` blocks
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_ast_passes/src/feature_gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ pub fn check_crate(krate: &ast::Crate, sess: &Session, features: &Features) {
"async closures are unstable",
"to use an async block, remove the `||`: `async {`"
);
gate_all!(async_for_loop, "`for await` loops are experimental");
gate_all!(
closure_lifetime_binder,
"`for<...>` binders for closures are experimental",
Expand Down
10 changes: 7 additions & 3 deletions compiler/rustc_ast_pretty/src/pprust/state/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::pp::Breaks::Inconsistent;
use crate::pprust::state::{AnnNode, PrintState, State, INDENT_UNIT};
use ast::ForLoopKind;
use itertools::{Itertools, Position};
use rustc_ast::ptr::P;
use rustc_ast::token;
Expand Down Expand Up @@ -418,20 +419,23 @@ impl<'a> State<'a> {
self.space();
self.print_block_with_attrs(blk, attrs);
}
ast::ExprKind::ForLoop(pat, iter, blk, opt_label) => {
if let Some(label) = opt_label {
ast::ExprKind::ForLoop { pat, iter, body, label, kind } => {
if let Some(label) = label {
self.print_ident(label.ident);
self.word_space(":");
}
self.cbox(0);
self.ibox(0);
self.word_nbsp("for");
if kind == &ForLoopKind::ForAwait {
self.word_nbsp("await");
}
self.print_pat(pat);
self.space();
self.word_space("in");
self.print_expr_as_cond(iter);
self.space();
self.print_block_with_attrs(blk, attrs);
self.print_block_with_attrs(body, attrs);
}
ast::ExprKind::Loop(blk, opt_label, _) => {
if let Some(label) = opt_label {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_builtin_macros/src/assert/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
| ExprKind::Continue(_)
| ExprKind::Err
| ExprKind::Field(_, _)
| ExprKind::ForLoop(_, _, _, _)
| ExprKind::ForLoop { .. }
| ExprKind::FormatArgs(_)
| ExprKind::IncludedBytes(..)
| ExprKind::InlineAsm(_)
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_feature/src/unstable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ declare_features! (
(unstable, async_closure, "1.37.0", Some(62290)),
/// Allows `#[track_caller]` on async functions.
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
/// Allows `for await` loops.
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)),
/// Allows builtin # foo() syntax
(unstable, builtin_syntax, "1.71.0", Some(110680)),
/// Treat `extern "C"` function as nounwind.
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ language_item_table! {
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

AsyncIteratorPollNext, sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(0);

Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;
OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None;
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_lint/src/unused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,8 +852,8 @@ trait UnusedDelimLint {
(cond, UnusedDelimsCtx::WhileCond, true, Some(left), Some(right), true)
}

ForLoop(_, ref cond, ref block, ..) => {
(cond, UnusedDelimsCtx::ForIterExpr, true, None, Some(block.span.lo()), true)
ForLoop { ref iter, ref body, .. } => {
(iter, UnusedDelimsCtx::ForIterExpr, true, None, Some(body.span.lo()), true)
}

Match(ref head, _) if Self::LINT_EXPR_IN_PATTERN_MATCHING_CTX => {
Expand Down Expand Up @@ -1085,7 +1085,7 @@ impl EarlyLintPass for UnusedParens {
}

match e.kind {
ExprKind::Let(ref pat, _, _, _) | ExprKind::ForLoop(ref pat, ..) => {
ExprKind::Let(ref pat, _, _, _) | ExprKind::ForLoop { ref pat, .. } => {
self.check_unused_parens_pat(cx, pat, false, false, (true, true));
}
// We ignore parens in cases like `if (((let Some(0) = Some(1))))` because we already
Expand Down
Loading

0 comments on commit 208dd20

Please sign in to comment.