Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FnMut::call_mut/Fn::call shim for async closures that capture references #127136

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_ssa/src/mir/locals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let expected_ty = self.monomorphize(self.mir.local_decls[local].ty);
if expected_ty != op.layout.ty {
warn!(
"Unexpected initial operand type: expected {expected_ty:?}, found {:?}.\
"Unexpected initial operand type:\nexpected {expected_ty:?},\nfound {:?}.\n\
See <https://github.com/rust-lang/rust/issues/114858>.",
op.layout.ty
);
Expand Down
49 changes: 31 additions & 18 deletions compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use rustc_hir as hir;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, EarlyBinder, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};

use rustc_index::{Idx, IndexVec};

use rustc_span::{source_map::Spanned, Span, DUMMY_SP};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_target::spec::abi::Abi;

use std::assert_matches::assert_matches;
use std::fmt;
use std::iter;

Expand Down Expand Up @@ -1020,21 +1019,19 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
receiver_by_ref: bool,
) -> Body<'tcx> {
let mut self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let mut self_local: Place<'tcx> = Local::from_usize(1).into();
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
bug!();
};

// We use `&mut Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly (which might take `&self` instead).
// We use `&Self` here because we only need to emit an ABI-compatible shim body,
// rather than match the signature exactly (which might take `&mut self` instead).
//
// The self type here is a coroutine-closure, not a coroutine, and we never read from
// it because it never has any captures, because this is only true in the Fn/FnMut
// implementation, not the AsyncFn/AsyncFnMut implementation, which is implemented only
// if the coroutine-closure has no captures.
// We adjust the `self_local` to be a deref since we want to copy fields out of
// a reference to the closure.
if receiver_by_ref {
// Triple-check that there's no captures here.
assert_eq!(args.as_coroutine_closure().tupled_upvars_ty(), tcx.types.unit);
self_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, self_ty);
self_local = tcx.mk_place_deref(self_local);
self_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, self_ty);
}

let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
Expand Down Expand Up @@ -1067,11 +1064,27 @@ fn build_construct_coroutine_by_move_shim<'tcx>(
fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
}
for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
fields.push(Operand::Move(tcx.mk_place_field(
Local::from_usize(1).into(),
FieldIdx::from_usize(idx),
ty,
)));
if receiver_by_ref {
// The only situation where it's possible is when we capture immuatable references,
// since those don't need to be reborrowed with the closure's env lifetime. Since
// references are always `Copy`, just emit a copy.
assert_matches!(
ty.kind(),
ty::Ref(_, _, hir::Mutability::Not),
"field should be captured by immutable ref if we have an `Fn` instance"
);
fields.push(Operand::Copy(tcx.mk_place_field(
self_local,
FieldIdx::from_usize(idx),
ty,
)));
} else {
fields.push(Operand::Move(tcx.mk_place_field(
self_local,
FieldIdx::from_usize(idx),
ty,
)));
}
}

let source_info = SourceInfo::outermost(span);
Expand Down
10 changes: 7 additions & 3 deletions compiler/rustc_symbol_mangling/src/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,13 @@ pub(super) fn mangle<'tcx>(
}
// FIXME(async_closures): This shouldn't be needed when we fix
// `Instance::ty`/`Instance::def_id`.
ty::InstanceKind::ConstructCoroutineInClosureShim { .. }
| ty::InstanceKind::CoroutineKindShim { .. } => {
printer.write_str("{{fn-once-shim}}").unwrap();
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref, .. } => {
printer
.write_str(if receiver_by_ref { "{{by-move-shim}}" } else { "{{by-ref-shim}}" })
.unwrap();
}
ty::InstanceKind::CoroutineKindShim { .. } => {
printer.write_str("{{by-move-body-shim}}").unwrap();
}
_ => {}
}
Expand Down
11 changes: 9 additions & 2 deletions compiler/rustc_symbol_mangling/src/v0.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@ pub(super) fn mangle<'tcx>(
ty::InstanceKind::ReifyShim(_, Some(ReifyReason::FnPtr)) => Some("reify_fnptr"),
ty::InstanceKind::ReifyShim(_, Some(ReifyReason::Vtable)) => Some("reify_vtable"),

ty::InstanceKind::ConstructCoroutineInClosureShim { .. }
| ty::InstanceKind::CoroutineKindShim { .. } => Some("fn_once"),
// FIXME(async_closures): This shouldn't be needed when we fix
// `Instance::ty`/`Instance::def_id`.
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref: true, .. } => {
Some("by_move")
}
ty::InstanceKind::ConstructCoroutineInClosureShim { receiver_by_ref: false, .. } => {
Some("by_ref")
}
ty::InstanceKind::CoroutineKindShim { .. } => Some("by_move_body"),

_ => None,
};
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ fn fn_sig_for_fn_abi<'tcx>(
coroutine_kind = ty::ClosureKind::FnOnce;

// Implementations of `FnMut` and `Fn` for coroutine-closures
// still take their receiver by (mut) ref.
// still take their receiver by ref.
if receiver_by_ref {
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty)
Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty)
} else {
coroutine_ty
}
Expand Down
21 changes: 14 additions & 7 deletions src/tools/miri/tests/pass/async-closure.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![feature(async_closure, noop_waker, async_fn_traits)]
#![allow(unused)]

use std::future::Future;
use std::ops::{AsyncFnMut, AsyncFnOnce};
use std::ops::{AsyncFn, AsyncFnMut, AsyncFnOnce};
use std::pin::pin;
use std::task::*;

Expand All @@ -17,6 +18,10 @@ pub fn block_on<T>(fut: impl Future<Output = T>) -> T {
}
}

async fn call(f: &mut impl AsyncFn(i32)) {
f(0).await;
}

async fn call_mut(f: &mut impl AsyncFnMut(i32)) {
f(0).await;
}
Expand All @@ -26,10 +31,10 @@ async fn call_once(f: impl AsyncFnOnce(i32)) {
}

async fn call_normal<F: Future<Output = ()>>(f: &impl Fn(i32) -> F) {
f(0).await;
f(1).await;
}

async fn call_normal_once<F: Future<Output = ()>>(f: impl FnOnce(i32) -> F) {
async fn call_normal_mut<F: Future<Output = ()>>(f: &mut impl FnMut(i32) -> F) {
f(1).await;
}

Expand All @@ -39,14 +44,16 @@ pub fn main() {
let mut async_closure = async move |a: i32| {
println!("{a} {b}");
};
call(&mut async_closure).await;
call_mut(&mut async_closure).await;
call_once(async_closure).await;

// No-capture closures implement `Fn`.
let async_closure = async move |a: i32| {
println!("{a}");
let b = 2i32;
let mut async_closure = async |a: i32| {
println!("{a} {b}");
};
call_normal(&async_closure).await;
call_normal_once(async_closure).await;
call_normal_mut(&mut async_closure).await;
call_once(async_closure).await;
});
}
6 changes: 4 additions & 2 deletions src/tools/miri/tests/pass/async-closure.stdout
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
0 2
0 2
1 2
1 2
1 2
1 2
0
1
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIR for `main::{closure#0}::{closure#0}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10}, _2: ResumeTy) -> ()
fn main::{closure#0}::{closure#0}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:53:33: 53:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:53:53: 56:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#0}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:42:33: 42:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:42:53: 45:10};
fn main::{closure#0}::{closure#0}(_1: {async closure@$DIR/async_closure_shims.rs:53:33: 53:52}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:53:53: 56:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:42:53: 45:10 (#0)} { a: move _2, b: move (_1.0: i32) };
_0 = {coroutine@$DIR/async_closure_shims.rs:53:53: 56:10 (#0)} { a: move _2, b: move (_1.0: i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// MIR for `main::{closure#0}::{closure#1}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#1}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
debug a => (_1.0: i32);
debug b => (*(_1.1: &i32));
let mut _0: ();
let _3: i32;
scope 1 {
debug a => _3;
let _4: &i32;
scope 2 {
debug a => _4;
let _5: &i32;
scope 3 {
debug b => _5;
}
}
}

bb0: {
StorageLive(_3);
_3 = (_1.0: i32);
FakeRead(ForLet(None), _3);
StorageLive(_4);
_4 = &_3;
FakeRead(ForLet(None), _4);
StorageLive(_5);
_5 = &(*(_1.1: &i32));
FakeRead(ForLet(None), _5);
_0 = const ();
StorageDead(_5);
StorageDead(_4);
StorageDead(_3);
drop(_1) -> [return: bb1, unwind: bb2];
}

bb1: {
return;
}

bb2 (cleanup): {
resume;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// MIR for `main::{closure#0}::{closure#1}::{closure#0}` 0 coroutine_by_move

fn main::{closure#0}::{closure#1}::{closure#0}(_1: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10}, _2: ResumeTy) -> ()
yields ()
{
debug _task_context => _2;
debug a => (_1.0: i32);
debug b => (*(_1.1: &i32));
let mut _0: ();
let _3: i32;
scope 1 {
debug a => _3;
let _4: &i32;
scope 2 {
debug a => _4;
let _5: &i32;
scope 3 {
debug b => _5;
}
}
}

bb0: {
StorageLive(_3);
_3 = (_1.0: i32);
FakeRead(ForLet(None), _3);
StorageLive(_4);
_4 = &_3;
FakeRead(ForLet(None), _4);
StorageLive(_5);
_5 = &(*(_1.1: &i32));
FakeRead(ForLet(None), _5);
_0 = const ();
StorageDead(_5);
StorageDead(_4);
StorageDead(_3);
drop(_1) -> [return: bb1, unwind: bb2];
}

bb1: {
return;
}

bb2 (cleanup): {
resume;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#1}(_1: {async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: move (_1.0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_move

fn main::{closure#0}::{closure#1}(_1: {async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: move (_1.0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: &mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
fn main::{closure#0}::{closure#1}(_1: &{async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: ((*_1).0: &i32) };
return;
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// MIR for `main::{closure#0}::{closure#1}` 0 coroutine_closure_by_ref

fn main::{closure#0}::{closure#1}(_1: &mut {async closure@$DIR/async_closure_shims.rs:49:29: 49:48}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:49:49: 51:10};
fn main::{closure#0}::{closure#1}(_1: &{async closure@$DIR/async_closure_shims.rs:62:33: 62:47}, _2: i32) -> {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10} {
let mut _0: {async closure body@$DIR/async_closure_shims.rs:62:48: 65:10};

bb0: {
_0 = {coroutine@$DIR/async_closure_shims.rs:49:49: 51:10 (#0)} { a: move _2 };
_0 = {coroutine@$DIR/async_closure_shims.rs:62:48: 65:10 (#0)} { a: move _2, b: ((*_1).0: &i32) };
return;
}
}
Loading
Loading