Skip to content

Commit

Permalink
relocate upvars to Unresumed state and make coroutine prefix trivial
Browse files Browse the repository at this point in the history
Co-authored-by: Dario Nieuwenhuis <dirbaio@dirbaio.net>
  • Loading branch information
dingxiangfei2009 and Dirbaio committed Oct 21, 2024
1 parent 31e102c commit e917475
Show file tree
Hide file tree
Showing 56 changed files with 1,016 additions and 442 deletions.
23 changes: 19 additions & 4 deletions compiler/rustc_borrowck/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use std::ops::Deref;
use consumers::{BodyWithBorrowckFacts, ConsumerOptions};
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_data_structures::graph::dominators::Dominators;
use rustc_data_structures::unord::UnordMap;
use rustc_errors::Diag;
use rustc_hir as hir;
use rustc_hir::def_id::LocalDefId;
Expand Down Expand Up @@ -287,6 +288,7 @@ fn do_mir_borrowck<'tcx>(
regioncx: &regioncx,
used_mut: Default::default(),
used_mut_upvars: SmallVec::new(),
local_from_upvars: UnordMap::default(),
borrow_set: &borrow_set,
upvars: &[],
local_names: IndexVec::from_elem(None, &promoted_body.local_decls),
Expand All @@ -313,6 +315,11 @@ fn do_mir_borrowck<'tcx>(
}
}

let mut local_from_upvars = UnordMap::default();
for (field, &local) in body.local_upvar_map.iter_enumerated() {
let Some(local) = local else { continue };
local_from_upvars.insert(local, field);
}
let mut mbcx = MirBorrowckCtxt {
infcx: &infcx,
param_env,
Expand All @@ -328,6 +335,7 @@ fn do_mir_borrowck<'tcx>(
regioncx: &regioncx,
used_mut: Default::default(),
used_mut_upvars: SmallVec::new(),
local_from_upvars,
borrow_set: &borrow_set,
upvars: tcx.closure_captures(def),
local_names,
Expand Down Expand Up @@ -563,6 +571,9 @@ struct MirBorrowckCtxt<'a, 'infcx, 'tcx> {
/// If the function we're checking is a closure, then we'll need to report back the list of
/// mutable upvars that have been used. This field keeps track of them.
used_mut_upvars: SmallVec<[FieldIdx; 8]>,
/// Since upvars are moved to real locals, we need to map mutations to the locals back to
/// the upvars, so that used_mut_upvars is up-to-date.
local_from_upvars: UnordMap<Local, FieldIdx>,
/// Region inference context. This contains the results from region inference and lets us e.g.
/// find out which CFG points are contained in each borrow region.
regioncx: &'a RegionInferenceContext<'tcx>,
Expand Down Expand Up @@ -2218,10 +2229,12 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
// If the local may have been initialized, and it is now currently being
// mutated, then it is justified to be annotated with the `mut`
// keyword, since the mutation may be a possible reassignment.
if is_local_mutation_allowed != LocalMutationIsAllowed::Yes
&& self.is_local_ever_initialized(local, state).is_some()
{
self.used_mut.insert(local);
if !matches!(is_local_mutation_allowed, LocalMutationIsAllowed::Yes) {
if self.is_local_ever_initialized(local, state).is_some() {
self.used_mut.insert(local);
} else if let Some(&field) = self.local_from_upvars.get(&local) {
self.used_mut_upvars.push(field);
}
}
}
RootPlace {
Expand All @@ -2239,6 +2252,8 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, '_, 'tcx> {
projection: place_projection,
}) {
self.used_mut_upvars.push(field);
} else if let Some(&field) = self.local_from_upvars.get(&place_local) {
self.used_mut_upvars.push(field);
}
}
}
Expand Down
24 changes: 12 additions & 12 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -812,15 +812,15 @@ impl<'a, 'b, 'tcx> TypeVerifier<'a, 'b, 'tcx> {
}),
};
}
ty::Coroutine(_, args) => {
ty::Coroutine(_def_id, args) => {
// Only prefix fields (upvars and current state) are
// accessible without a variant index.
return match args.as_coroutine().prefix_tys().get(field.index()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
};
let upvar_tys = args.as_coroutine().upvar_tys();
if let Some(ty) = upvar_tys.get(field.index()) {
return Ok(*ty);
} else {
return Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() });
}
}
ty::Tuple(tys) => {
return match tys.get(field.index()) {
Expand Down Expand Up @@ -1837,11 +1837,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
// It doesn't make sense to look at a field beyond the prefix;
// these require a variant index, and are not initialized in
// aggregate rvalues.
match args.as_coroutine().prefix_tys().get(field_index.as_usize()) {
Some(ty) => Ok(*ty),
None => Err(FieldAccessError::OutOfRange {
field_count: args.as_coroutine().prefix_tys().len(),
}),
let upvar_tys = &args.as_coroutine().upvar_tys();
if let Some(ty) = upvar_tys.get(field_index.as_usize()) {
Ok(*ty)
} else {
Err(FieldAccessError::OutOfRange { field_count: upvar_tys.len() })
}
}
AggregateKind::CoroutineClosure(_, args) => {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_cranelift/src/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,9 @@ fn codegen_stmt<'tcx>(
let variant_dest = lval.downcast_variant(fx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_def_id, _args) => {
(FIRST_VARIANT, lval.downcast_variant(fx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, lval, None),
};
if active_field_index.is_some() {
Expand Down
5 changes: 2 additions & 3 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ use rustc_hir::def_id::{DefId, LOCAL_CRATE};
use rustc_middle::bug;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{
self, AdtKind, CoroutineArgsExt, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt,
Visibility,
self, AdtKind, Instance, ParamEnv, PolyExistentialTraitRef, Ty, TyCtxt, Visibility,
};
use rustc_session::config::{self, DebugInfo, Lto};
use rustc_span::symbol::Symbol;
Expand Down Expand Up @@ -1124,7 +1123,7 @@ fn build_upvar_field_di_nodes<'ll, 'tcx>(
closure_or_coroutine_di_node: &'ll DIType,
) -> SmallVec<&'ll DIType> {
let (&def_id, up_var_tys) = match closure_or_coroutine_ty.kind() {
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().prefix_tys()),
ty::Coroutine(def_id, args) => (def_id, args.as_coroutine().upvar_tys()),
ty::Closure(def_id, args) => (def_id, args.as_closure().upvar_tys()),
ty::CoroutineClosure(def_id, args) => (def_id, args.as_coroutine_closure().upvar_tys()),
_ => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
let coroutine_layout =
cx.tcx.coroutine_layout(coroutine_def_id, coroutine_args.kind_ty()).unwrap();

let common_upvar_names = cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);
let variant_range = coroutine_args.variant_range(coroutine_def_id, cx.tcx);
let variant_count = (variant_range.start.as_u32()..variant_range.end.as_u32()).len();

Expand Down Expand Up @@ -707,7 +706,6 @@ fn build_union_fields_for_direct_tag_coroutine<'ll, 'tcx>(
coroutine_type_and_layout,
coroutine_type_di_node,
coroutine_layout,
common_upvar_names,
);

let span = coroutine_layout.variant_source_info[variant_index].span;
Expand Down
33 changes: 2 additions & 31 deletions compiler/rustc_codegen_llvm/src/debuginfo/metadata/enums/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::borrow::Cow;
use rustc_codegen_ssa::debuginfo::type_names::{compute_debuginfo_type_name, cpp_like_debuginfo};
use rustc_codegen_ssa::debuginfo::{tag_base_type, wants_c_like_enum_debuginfo};
use rustc_hir::def::CtorKind;
use rustc_index::IndexSlice;
use rustc_middle::bug;
use rustc_middle::mir::CoroutineLayout;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::{self, AdtDef, CoroutineArgs, CoroutineArgsExt, Ty, VariantDef};
use rustc_span::Symbol;
use rustc_target::abi::{FieldIdx, TagEncoding, VariantIdx, Variants};

use super::type_map::{DINodeCreationResult, UniqueTypeId};
Expand Down Expand Up @@ -263,7 +261,6 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
coroutine_type_and_layout: TyAndLayout<'tcx>,
coroutine_type_di_node: &'ll DIType,
coroutine_layout: &CoroutineLayout<'tcx>,
common_upvar_names: &IndexSlice<FieldIdx, Symbol>,
) -> &'ll DIType {
let variant_name = CoroutineArgs::variant_name(variant_index);
let unique_type_id = UniqueTypeId::for_enum_variant_struct_type(
Expand All @@ -274,11 +271,6 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(

let variant_layout = coroutine_type_and_layout.for_variant(cx, variant_index);

let coroutine_args = match coroutine_type_and_layout.ty.kind() {
ty::Coroutine(_, args) => args.as_coroutine(),
_ => unreachable!(),
};

type_map::build_type_with_children(
cx,
type_map::stub(
Expand All @@ -292,7 +284,7 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
),
|cx, variant_struct_type_di_node| {
// Fields that just belong to this variant/state
let state_specific_fields: SmallVec<_> = (0..variant_layout.fields.count())
(0..variant_layout.fields.count())
.map(|field_index| {
let coroutine_saved_local = coroutine_layout.variant_fields[variant_index]
[FieldIdx::from_usize(field_index)];
Expand All @@ -314,28 +306,7 @@ fn build_coroutine_variant_struct_type_di_node<'ll, 'tcx>(
type_di_node(cx, field_type),
)
})
.collect();

// Fields that are common to all states
let common_fields: SmallVec<_> = coroutine_args
.prefix_tys()
.iter()
.zip(common_upvar_names)
.enumerate()
.map(|(index, (upvar_ty, upvar_name))| {
build_field_di_node(
cx,
variant_struct_type_di_node,
upvar_name.as_str(),
cx.size_and_align_of(upvar_ty),
coroutine_type_and_layout.fields.offset(index),
DIFlags::FlagZero,
type_di_node(cx, upvar_ty),
)
})
.collect();

state_specific_fields.into_iter().chain(common_fields).collect()
.collect()
},
|cx| build_generic_type_param_di_nodes(cx, coroutine_type_and_layout.ty),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
)
};

let common_upvar_names =
cx.tcx.closure_saved_names_of_captured_variables(coroutine_def_id);

// Build variant struct types
let variant_struct_type_di_nodes: SmallVec<_> = variants
.indices()
Expand Down Expand Up @@ -190,7 +187,6 @@ pub(super) fn build_coroutine_di_node<'ll, 'tcx>(
coroutine_type_and_layout,
coroutine_type_di_node,
coroutine_layout,
common_upvar_names,
),
source_info,
}
Expand Down
11 changes: 4 additions & 7 deletions compiler/rustc_codegen_ssa/src/mir/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_middle::mir::interpret::{Pointer, Scalar, alloc_range};
use rustc_middle::mir::{self, ConstValue};
use rustc_middle::ty::Ty;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use tracing::debug;
use tracing::{debug, instrument};

use super::place::{PlaceRef, PlaceValue};
use super::{FunctionCx, LocalRef};
Expand Down Expand Up @@ -551,13 +551,12 @@ impl<'a, 'tcx, V: CodegenObject> OperandValue<V> {
}

impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
#[instrument(level = "debug", skip(self, bx), ret)]
fn maybe_codegen_consume_direct(
&mut self,
bx: &mut Bx,
place_ref: mir::PlaceRef<'tcx>,
) -> Option<OperandRef<'tcx, Bx::Value>> {
debug!("maybe_codegen_consume_direct(place_ref={:?})", place_ref);

match self.locals[place_ref.local] {
LocalRef::Operand(mut o) => {
// Moves out of scalar and scalar pair fields are trivial.
Expand Down Expand Up @@ -600,13 +599,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}

#[instrument(level = "debug", skip(self, bx), ret)]
pub fn codegen_consume(
&mut self,
bx: &mut Bx,
place_ref: mir::PlaceRef<'tcx>,
) -> OperandRef<'tcx, Bx::Value> {
debug!("codegen_consume(place_ref={:?})", place_ref);

let ty = self.monomorphized_place_ty(place_ref);
let layout = bx.cx().layout_of(ty);

Expand All @@ -625,13 +623,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
bx.load_operand(place)
}

#[instrument(level = "debug", skip(self, bx), ret)]
pub fn codegen_operand(
&mut self,
bx: &mut Bx,
operand: &mir::Operand<'tcx>,
) -> OperandRef<'tcx, Bx::Value> {
debug!("codegen_operand(operand={:?})", operand);

match *operand {
mir::Operand::Copy(ref place) | mir::Operand::Move(ref place) => {
self.codegen_consume(bx, place.as_ref())
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
let variant_dest = dest.project_downcast(bx, variant_index);
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_, _) => {
(FIRST_VARIANT, dest.project_downcast(bx, FIRST_VARIANT), None)
}
_ => (FIRST_VARIANT, dest, None),
};
if active_field_index.is_some() {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_const_eval/src/interpret/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,9 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let variant_dest = self.project_downcast(dest, variant_index)?;
(variant_index, variant_dest, active_field_index)
}
mir::AggregateKind::Coroutine(_def_id, _args) => {
(FIRST_VARIANT, self.project_downcast(dest, FIRST_VARIANT)?, None)
}
mir::AggregateKind::RawPtr(..) => {
// Pointers don't have "fields" in the normal sense, so the
// projection-based code below would either fail in projection
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_middle/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@ pub struct Body<'tcx> {
/// If `-Cinstrument-coverage` is not active, or if an individual function
/// is not eligible for coverage, then this should always be `None`.
pub function_coverage_info: Option<Box<coverage::FunctionCoverageInfo>>,

/// Coroutine local-upvar map
pub local_upvar_map: IndexVec<FieldIdx, Option<Local>>,
}

impl<'tcx> Body<'tcx> {
Expand Down Expand Up @@ -411,6 +414,7 @@ impl<'tcx> Body<'tcx> {
tainted_by_errors,
coverage_info_hi: None,
function_coverage_info: None,
local_upvar_map: IndexVec::new(),
};
body.is_polymorphic = body.has_non_region_param();
body
Expand Down Expand Up @@ -442,6 +446,7 @@ impl<'tcx> Body<'tcx> {
tainted_by_errors: None,
coverage_info_hi: None,
function_coverage_info: None,
local_upvar_map: IndexVec::new(),
};
body.is_polymorphic = body.has_non_region_param();
body
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/mir/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,14 @@ impl<'tcx> MirPatch<'tcx> {
ty: Ty<'tcx>,
span: Span,
local_info: LocalInfo<'tcx>,
immutable: bool,
) -> Local {
let index = self.next_local;
self.next_local += 1;
let mut new_decl = LocalDecl::new(ty, span);
if immutable {
new_decl = new_decl.immutable();
}
**new_decl.local_info.as_mut().assert_crate_local() = local_info;
self.new_locals.push(new_decl);
Local::new(index)
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,9 +891,10 @@ where
),
Variants::Multiple { tag, tag_field, .. } => {
if i == tag_field {
return TyMaybeWithLayout::TyAndLayout(tag_layout(tag));
TyMaybeWithLayout::TyAndLayout(tag_layout(tag))
} else {
TyMaybeWithLayout::Ty(args.as_coroutine().upvar_tys()[i])
}
TyMaybeWithLayout::Ty(args.as_coroutine().prefix_tys()[i])
}
},

Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,6 @@ impl<'tcx> ty::CoroutineArgs<TyCtxt<'tcx>> {
})
})
}

/// This is the types of the fields of a coroutine which are not stored in a
/// variant.
#[inline]
fn prefix_tys(self) -> &'tcx List<Ty<'tcx>> {
self.upvar_tys()
}
}

#[derive(Debug, Copy, Clone, HashStable, TypeFoldable, TypeVisitable)]
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir_build/src/build/custom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub(super) fn build_custom_mir<'tcx>(
pass_count: 0,
coverage_info_hi: None,
function_coverage_info: None,
local_upvar_map: IndexVec::new(),
};

body.local_decls.push(LocalDecl::new(return_ty, return_ty_span));
Expand Down
Loading

0 comments on commit e917475

Please sign in to comment.