Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add way to differentiate argument locals from other locals in Stable MIR #117095

Merged
merged 9 commits into from
Oct 26, 2023
11 changes: 5 additions & 6 deletions compiler/rustc_smir/src/rustc_smir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,8 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> {
type T = stable_mir::mir::Body;

fn stable(&self, tables: &mut Tables<'tcx>) -> Self::T {
stable_mir::mir::Body {
blocks: self
.basic_blocks
stable_mir::mir::Body::new(
self.basic_blocks
.iter()
.map(|block| stable_mir::mir::BasicBlock {
terminator: block.terminator().stable(tables),
Expand All @@ -300,15 +299,15 @@ impl<'tcx> Stable<'tcx> for mir::Body<'tcx> {
.collect(),
})
.collect(),
locals: self
.local_decls
self.local_decls
.iter()
.map(|decl| stable_mir::mir::LocalDecl {
ty: decl.ty.stable(tables),
span: decl.source_info.span.stable(tables),
})
.collect(),
}
self.arg_count,
)
}
}

Expand Down
56 changes: 53 additions & 3 deletions compiler/stable_mir/src/mir/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,60 @@ use crate::ty::{AdtDef, ClosureDef, Const, CoroutineDef, GenericArgs, Movability
use crate::Opaque;
use crate::{ty::Ty, Span};

/// The SMIR representation of a single function.
#[derive(Clone, Debug)]
pub struct Body {
pub blocks: Vec<BasicBlock>,
pub locals: LocalDecls,

// Declarations of locals within the function.
//
// The first local is the return value pointer, followed by `arg_count`
// locals for the function arguments, followed by any user-declared
// variables and temporaries.
locals: LocalDecls,

// The number of arguments this function takes.
arg_count: usize,
}

impl Body {
/// Constructs a `Body`.
///
/// A constructor is required to build a `Body` from outside the crate
/// because the `arg_count` and `locals` fields are private.
pub fn new(blocks: Vec<BasicBlock>, locals: LocalDecls, arg_count: usize) -> Self {
// If locals doesn't contain enough entries, it can lead to panics in
// `ret_local`, `arg_locals`, and `inner_locals`.
assert!(
locals.len() > arg_count,
"A Body must contain at least a local for the return value and each of the function's arguments"
);
Self { blocks, locals, arg_count }
}

/// Return local that holds this function's return value.
pub fn ret_local(&self) -> &LocalDecl {
&self.locals[0]
}

/// Locals in `self` that correspond to this function's arguments.
pub fn arg_locals(&self) -> &[LocalDecl] {
&self.locals[1..self.arg_count + 1]
oli-obk marked this conversation as resolved.
Show resolved Hide resolved
}

/// Inner locals for this function. These are the locals that are
/// neither the return local nor the argument locals.
pub fn inner_locals(&self) -> &[LocalDecl] {
&self.locals[self.arg_count + 1..]
}

/// Convenience function to get all the locals in this function.
///
/// Locals are typically accessed via the more specific methods `ret_local`,
/// `arg_locals`, and `inner_locals`.
pub fn locals(&self) -> &[LocalDecl] {
&self.locals
}
}

type LocalDecls = Vec<LocalDecl>;
Expand Down Expand Up @@ -467,7 +517,7 @@ pub enum NullOp {
}

impl Operand {
pub fn ty(&self, locals: &LocalDecls) -> Ty {
pub fn ty(&self, locals: &[LocalDecl]) -> Ty {
match self {
Operand::Copy(place) | Operand::Move(place) => place.ty(locals),
Operand::Constant(c) => c.ty(),
Expand All @@ -482,7 +532,7 @@ impl Constant {
}

impl Place {
pub fn ty(&self, locals: &LocalDecls) -> Ty {
pub fn ty(&self, locals: &[LocalDecl]) -> Ty {
let _start_ty = locals[self.local].ty;
todo!("Implement projection")
}
Expand Down
2 changes: 1 addition & 1 deletion tests/ui-fulldeps/stable-mir/check_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ fn test_body(body: mir::Body) {
for term in body.blocks.iter().map(|bb| &bb.terminator) {
match &term.kind {
Call { func, .. } => {
let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() };
let TyKind::RigidTy(ty) = func.ty(body.locals()).kind() else { unreachable!() };
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
let result = Instance::resolve(def, &args);
assert!(result.is_ok());
Expand Down
53 changes: 42 additions & 11 deletions tests/ui-fulldeps/stable-mir/crate-info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {

let bar = get_item(&items, (DefKind::Fn, "bar")).unwrap();
let body = bar.body();
assert_eq!(body.locals.len(), 2);
assert_eq!(body.locals().len(), 2);
assert_eq!(body.blocks.len(), 1);
let block = &body.blocks[0];
assert_eq!(block.statements.len(), 1);
Expand All @@ -62,7 +62,7 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {

let foo_bar = get_item(&items, (DefKind::Fn, "foo_bar")).unwrap();
let body = foo_bar.body();
assert_eq!(body.locals.len(), 5);
assert_eq!(body.locals().len(), 5);
assert_eq!(body.blocks.len(), 4);
let block = &body.blocks[0];
match &block.terminator.kind {
Expand All @@ -72,29 +72,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {

let types = get_item(&items, (DefKind::Fn, "types")).unwrap();
let body = types.body();
assert_eq!(body.locals.len(), 6);
assert_eq!(body.locals().len(), 6);
assert_matches!(
body.locals[0].ty.kind(),
body.locals()[0].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
);
assert_matches!(
body.locals[1].ty.kind(),
body.locals()[1].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
);
assert_matches!(
body.locals[2].ty.kind(),
body.locals()[2].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char)
);
assert_matches!(
body.locals[3].ty.kind(),
body.locals()[3].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32))
);
assert_matches!(
body.locals[4].ty.kind(),
body.locals()[4].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64))
);
assert_matches!(
body.locals[5].ty.kind(),
body.locals()[5].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Float(
stable_mir::ty::FloatTy::F64
))
Expand Down Expand Up @@ -123,10 +123,10 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
for block in instance.body().blocks {
match &block.terminator.kind {
stable_mir::mir::TerminatorKind::Call { func, .. } => {
let TyKind::RigidTy(ty) = func.ty(&body.locals).kind() else { unreachable!() };
let TyKind::RigidTy(ty) = func.ty(&body.locals()).kind() else { unreachable!() };
let RigidTy::FnDef(def, args) = ty else { unreachable!() };
let next_func = Instance::resolve(def, &args).unwrap();
match next_func.body().locals[1].ty.kind() {
match next_func.body().locals()[1].ty.kind() {
TyKind::RigidTy(RigidTy::Uint(_)) | TyKind::RigidTy(RigidTy::Tuple(_)) => {}
other => panic!("{other:?}"),
}
Expand All @@ -140,6 +140,29 @@ fn test_stable_mir(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
// Ensure we don't panic trying to get the body of a constant.
foo_const.body();

let locals_fn = get_item(&items, (DefKind::Fn, "locals")).unwrap();
let body = locals_fn.body();
assert_eq!(body.locals().len(), 4);
assert_matches!(
body.ret_local().ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Char)
);
assert_eq!(body.arg_locals().len(), 2);
assert_matches!(
body.arg_locals()[0].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Int(stable_mir::ty::IntTy::I32))
);
assert_matches!(
body.arg_locals()[1].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Uint(stable_mir::ty::UintTy::U64))
);
assert_eq!(body.inner_locals().len(), 1);
// If conditions have an extra inner local to hold their results
assert_matches!(
body.inner_locals()[0].ty.kind(),
stable_mir::ty::TyKind::RigidTy(stable_mir::ty::RigidTy::Bool)
);

ControlFlow::Continue(())
}

Expand Down Expand Up @@ -211,6 +234,14 @@ fn generate_input(path: &str) -> std::io::Result<()> {

pub fn assert(x: i32) -> i32 {{
x + 1
}}

pub fn locals(a: i32, _: u64) -> char {{
if a > 5 {{
'a'
}} else {{
'b'
}}
}}"#
)?;
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion tests/ui-fulldeps/stable-mir/smir_internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ const CRATE_NAME: &str = "input";
fn test_translation(_tcx: TyCtxt<'_>) -> ControlFlow<()> {
let main_fn = stable_mir::entry_fn().unwrap();
let body = main_fn.body();
let orig_ty = body.locals[0].ty;
let orig_ty = body.locals()[0].ty;
let rustc_ty = rustc_internal::internal(&orig_ty);
assert!(rustc_ty.is_unit());
ControlFlow::Continue(())
Expand Down
Loading