Skip to content

Commit

Permalink
Auto merge of #104975 - JakobDegen:custom_mir_let, r=oli-obk
Browse files Browse the repository at this point in the history
`#![custom_mir]`: Various improvements

This PR makes a bunch of improvements to `#![custom_mir]`. Ideally this would be 4 PRs, one for each commit, but those would take forever to get merged and be a pain to juggle. Should still be reviewed one commit at a time though.

### Commit 1: Support arbitrary `let`

Before this change, all locals used in the body need to be declared at the top of the `mir!` invocation, which is rather annoying. We attempt to change that.

Unfortunately, we still have the requirement that the output of the `mir!` macro must resolve, typecheck, etc. Because of that, we can't just accept this in the THIR -> MIR parser because something like
```rust
{
    let x = 0;
    Goto(other)
}
other = {
    RET = x;
    Return()
}
```
will fail to resolve. Instead, the implementation does macro shenanigans to find the let declarations and extract them as part of the `mir!` macro. That *works*, but it is fairly complicated and degrades debuginfo by quite a bit. Specifically, the spans for any statements and declarations that are affected by this are completely wrong. My guess is that this is a net improvement though.

One way to recover some of the debuginfo would be to not support type annotations in the `let` statements, which would allow us to parse like `let $stmt:stmt`. That seems quite surprising though.

### Commit 2: Parse consts

Reuses most of the const parsing from regular Mir building for building custom mir

### Commit 3: Parse statics

Statics are slightly weird because the Mir primitive associated with them is a reference/pointer to them, so this is factored out separately.

### Commit 4: Fix some spans

A bunch of the spans were non-ideal, so we adjust them to be much more helpful.

r? `@oli-obk`
  • Loading branch information
bors committed Dec 1, 2022
2 parents d6c4de0 + 5a34dbf commit 9c0bc30
Show file tree
Hide file tree
Showing 15 changed files with 454 additions and 104 deletions.
4 changes: 2 additions & 2 deletions compiler/rustc_mir_build/src/build/custom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub(super) fn build_custom_mir<'tcx>(
let mut pctxt = ParseCtxt {
tcx,
thir,
source_info: SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE },
source_scope: OUTERMOST_SOURCE_SCOPE,
body: &mut body,
local_map: FxHashMap::default(),
block_map: FxHashMap::default(),
Expand Down Expand Up @@ -128,7 +128,7 @@ fn parse_attribute(attr: &Attribute) -> MirPhase {
struct ParseCtxt<'tcx, 'body> {
tcx: TyCtxt<'tcx>,
thir: &'body Thir<'tcx>,
source_info: SourceInfo,
source_scope: SourceScope,

body: &'body mut Body<'tcx>,
local_map: FxHashMap<LocalVarId, Local>,
Expand Down
25 changes: 18 additions & 7 deletions compiler/rustc_mir_build/src/build/custom/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ macro_rules! parse_by_kind {
(
$self:ident,
$expr_id:expr,
$expr_name:pat,
$expected:literal,
$(
@call($name:literal, $args:ident) => $call_expr:expr,
Expand All @@ -33,6 +34,8 @@ macro_rules! parse_by_kind {
) => {{
let expr_id = $self.preparse($expr_id);
let expr = &$self.thir[expr_id];
debug!("Trying to parse {:?} as {}", expr.kind, $expected);
let $expr_name = expr;
match &expr.kind {
$(
ExprKind::Call { ty, fun: _, args: $args, .. } if {
Expand Down Expand Up @@ -137,26 +140,26 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
/// This allows us to easily parse the basic blocks declarations, local declarations, and
/// basic block definitions in order.
pub fn parse_body(&mut self, expr_id: ExprId) -> PResult<()> {
let body = parse_by_kind!(self, expr_id, "whole body",
let body = parse_by_kind!(self, expr_id, _, "whole body",
ExprKind::Block { block } => self.thir[*block].expr.unwrap(),
);
let (block_decls, rest) = parse_by_kind!(self, body, "body with block decls",
let (block_decls, rest) = parse_by_kind!(self, body, _, "body with block decls",
ExprKind::Block { block } => {
let block = &self.thir[*block];
(&block.stmts, block.expr.unwrap())
},
);
self.parse_block_decls(block_decls.iter().copied())?;

let (local_decls, rest) = parse_by_kind!(self, rest, "body with local decls",
let (local_decls, rest) = parse_by_kind!(self, rest, _, "body with local decls",
ExprKind::Block { block } => {
let block = &self.thir[*block];
(&block.stmts, block.expr.unwrap())
},
);
self.parse_local_decls(local_decls.iter().copied())?;

let block_defs = parse_by_kind!(self, rest, "body with block defs",
let block_defs = parse_by_kind!(self, rest, _, "body with block defs",
ExprKind::Block { block } => &self.thir[*block].stmts,
);
for (i, block_def) in block_defs.iter().enumerate() {
Expand Down Expand Up @@ -223,22 +226,30 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
}

fn parse_block_def(&self, expr_id: ExprId) -> PResult<BasicBlockData<'tcx>> {
let block = parse_by_kind!(self, expr_id, "basic block",
let block = parse_by_kind!(self, expr_id, _, "basic block",
ExprKind::Block { block } => &self.thir[*block],
);

let mut data = BasicBlockData::new(None);
for stmt_id in &*block.stmts {
let stmt = self.statement_as_expr(*stmt_id)?;
let span = self.thir[stmt].span;
let statement = self.parse_statement(stmt)?;
data.statements.push(Statement { source_info: self.source_info, kind: statement });
data.statements.push(Statement {
source_info: SourceInfo { span, scope: self.source_scope },
kind: statement,
});
}

let Some(trailing) = block.expr else {
return Err(self.expr_error(expr_id, "terminator"))
};
let span = self.thir[trailing].span;
let terminator = self.parse_terminator(trailing)?;
data.terminator = Some(Terminator { source_info: self.source_info, kind: terminator });
data.terminator = Some(Terminator {
source_info: SourceInfo { span, scope: self.source_scope },
kind: terminator,
});

Ok(data)
}
Expand Down
47 changes: 40 additions & 7 deletions compiler/rustc_mir_build/src/build/custom/parse/instruction.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use rustc_middle::mir::interpret::{ConstValue, Scalar};
use rustc_middle::{mir::*, thir::*, ty};

use super::{parse_by_kind, PResult, ParseCtxt};

impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
pub fn parse_statement(&self, expr_id: ExprId) -> PResult<StatementKind<'tcx>> {
parse_by_kind!(self, expr_id, "statement",
parse_by_kind!(self, expr_id, _, "statement",
@call("mir_retag", args) => {
Ok(StatementKind::Retag(RetagKind::Default, Box::new(self.parse_place(args[0])?)))
},
Expand All @@ -20,7 +21,7 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
}

pub fn parse_terminator(&self, expr_id: ExprId) -> PResult<TerminatorKind<'tcx>> {
parse_by_kind!(self, expr_id, "terminator",
parse_by_kind!(self, expr_id, _, "terminator",
@call("mir_return", _args) => {
Ok(TerminatorKind::Return)
},
Expand All @@ -31,7 +32,7 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
}

fn parse_rvalue(&self, expr_id: ExprId) -> PResult<Rvalue<'tcx>> {
parse_by_kind!(self, expr_id, "rvalue",
parse_by_kind!(self, expr_id, _, "rvalue",
ExprKind::Borrow { borrow_kind, arg } => Ok(
Rvalue::Ref(self.tcx.lifetimes.re_erased, *borrow_kind, self.parse_place(*arg)?)
),
Expand All @@ -43,14 +44,26 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
}

fn parse_operand(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
parse_by_kind!(self, expr_id, "operand",
parse_by_kind!(self, expr_id, expr, "operand",
@call("mir_move", args) => self.parse_place(args[0]).map(Operand::Move),
@call("mir_static", args) => self.parse_static(args[0]),
@call("mir_static_mut", args) => self.parse_static(args[0]),
ExprKind::Literal { .. }
| ExprKind::NamedConst { .. }
| ExprKind::NonHirLiteral { .. }
| ExprKind::ZstLiteral { .. }
| ExprKind::ConstParam { .. }
| ExprKind::ConstBlock { .. } => {
Ok(Operand::Constant(Box::new(
crate::build::expr::as_constant::as_constant_inner(expr, |_| None, self.tcx)
)))
},
_ => self.parse_place(expr_id).map(Operand::Copy),
)
}

fn parse_place(&self, expr_id: ExprId) -> PResult<Place<'tcx>> {
parse_by_kind!(self, expr_id, "place",
parse_by_kind!(self, expr_id, _, "place",
ExprKind::Deref { arg } => Ok(
self.parse_place(*arg)?.project_deeper(&[PlaceElem::Deref], self.tcx)
),
Expand All @@ -59,14 +72,34 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
}

fn parse_local(&self, expr_id: ExprId) -> PResult<Local> {
parse_by_kind!(self, expr_id, "local",
parse_by_kind!(self, expr_id, _, "local",
ExprKind::VarRef { id } => Ok(self.local_map[id]),
)
}

fn parse_block(&self, expr_id: ExprId) -> PResult<BasicBlock> {
parse_by_kind!(self, expr_id, "basic block",
parse_by_kind!(self, expr_id, _, "basic block",
ExprKind::VarRef { id } => Ok(self.block_map[id]),
)
}

fn parse_static(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
let expr_id = parse_by_kind!(self, expr_id, _, "static",
ExprKind::Deref { arg } => *arg,
);

parse_by_kind!(self, expr_id, expr, "static",
ExprKind::StaticRef { alloc_id, ty, .. } => {
let const_val =
ConstValue::Scalar(Scalar::from_pointer((*alloc_id).into(), &self.tcx));
let literal = ConstantKind::Val(const_val, *ty);

Ok(Operand::Constant(Box::new(Constant {
span: expr.span,
user_ty: None,
literal
})))
},
)
}
}
137 changes: 71 additions & 66 deletions compiler/rustc_mir_build/src/build/expr/as_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use rustc_middle::mir::interpret::{
};
use rustc_middle::mir::*;
use rustc_middle::thir::*;
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, TyCtxt};
use rustc_middle::ty::{
self, CanonicalUserType, CanonicalUserTypeAnnotation, TyCtxt, UserTypeAnnotationIndex,
};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Size;

Expand All @@ -19,84 +21,87 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let this = self;
let tcx = this.tcx;
let Expr { ty, temp_lifetime: _, span, ref kind } = *expr;
match *kind {
match kind {
ExprKind::Scope { region_scope: _, lint_level: _, value } => {
this.as_constant(&this.thir[value])
}
ExprKind::Literal { lit, neg } => {
let literal =
match lit_to_mir_constant(tcx, LitToConstInput { lit: &lit.node, ty, neg }) {
Ok(c) => c,
Err(LitToConstError::Reported(guar)) => {
ConstantKind::Ty(tcx.const_error_with_guaranteed(ty, guar))
}
Err(LitToConstError::TypeError) => {
bug!("encountered type error in `lit_to_mir_constant")
}
};

Constant { span, user_ty: None, literal }
this.as_constant(&this.thir[*value])
}
ExprKind::NonHirLiteral { lit, ref user_ty } => {
let user_ty = user_ty.as_ref().map(|user_ty| {
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
_ => as_constant_inner(
expr,
|user_ty| {
Some(this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
span,
user_ty: user_ty.clone(),
inferred_ty: ty,
})
});
let literal = ConstantKind::Val(ConstValue::Scalar(Scalar::Int(lit)), ty);
}))
},
tcx,
),
}
}
}

Constant { span, user_ty: user_ty, literal }
}
ExprKind::ZstLiteral { ref user_ty } => {
let user_ty = user_ty.as_ref().map(|user_ty| {
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
span,
user_ty: user_ty.clone(),
inferred_ty: ty,
})
});
let literal = ConstantKind::Val(ConstValue::ZeroSized, ty);
pub fn as_constant_inner<'tcx>(
expr: &Expr<'tcx>,
push_cuta: impl FnMut(&Box<CanonicalUserType<'tcx>>) -> Option<UserTypeAnnotationIndex>,
tcx: TyCtxt<'tcx>,
) -> Constant<'tcx> {
let Expr { ty, temp_lifetime: _, span, ref kind } = *expr;
match *kind {
ExprKind::Literal { lit, neg } => {
let literal =
match lit_to_mir_constant(tcx, LitToConstInput { lit: &lit.node, ty, neg }) {
Ok(c) => c,
Err(LitToConstError::Reported(guar)) => {
ConstantKind::Ty(tcx.const_error_with_guaranteed(ty, guar))
}
Err(LitToConstError::TypeError) => {
bug!("encountered type error in `lit_to_mir_constant")
}
};

Constant { span, user_ty: user_ty, literal }
}
ExprKind::NamedConst { def_id, substs, ref user_ty } => {
let user_ty = user_ty.as_ref().map(|user_ty| {
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
span,
user_ty: user_ty.clone(),
inferred_ty: ty,
})
});
Constant { span, user_ty: None, literal }
}
ExprKind::NonHirLiteral { lit, ref user_ty } => {
let user_ty = user_ty.as_ref().map(push_cuta).flatten();

let uneval =
mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
let literal = ConstantKind::Unevaluated(uneval, ty);
let literal = ConstantKind::Val(ConstValue::Scalar(Scalar::Int(lit)), ty);

Constant { user_ty, span, literal }
}
ExprKind::ConstParam { param, def_id: _ } => {
let const_param = tcx.mk_const(param, expr.ty);
let literal = ConstantKind::Ty(const_param);
Constant { span, user_ty: user_ty, literal }
}
ExprKind::ZstLiteral { ref user_ty } => {
let user_ty = user_ty.as_ref().map(push_cuta).flatten();

Constant { user_ty: None, span, literal }
}
ExprKind::ConstBlock { did: def_id, substs } => {
let uneval =
mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
let literal = ConstantKind::Unevaluated(uneval, ty);
let literal = ConstantKind::Val(ConstValue::ZeroSized, ty);

Constant { user_ty: None, span, literal }
}
ExprKind::StaticRef { alloc_id, ty, .. } => {
let const_val = ConstValue::Scalar(Scalar::from_pointer(alloc_id.into(), &tcx));
let literal = ConstantKind::Val(const_val, ty);
Constant { span, user_ty: user_ty, literal }
}
ExprKind::NamedConst { def_id, substs, ref user_ty } => {
let user_ty = user_ty.as_ref().map(push_cuta).flatten();

Constant { span, user_ty: None, literal }
}
_ => span_bug!(span, "expression is not a valid constant {:?}", kind),
let uneval = mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
let literal = ConstantKind::Unevaluated(uneval, ty);

Constant { user_ty, span, literal }
}
ExprKind::ConstParam { param, def_id: _ } => {
let const_param = tcx.mk_const(ty::ConstKind::Param(param), expr.ty);
let literal = ConstantKind::Ty(const_param);

Constant { user_ty: None, span, literal }
}
ExprKind::ConstBlock { did: def_id, substs } => {
let uneval = mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
let literal = ConstantKind::Unevaluated(uneval, ty);

Constant { user_ty: None, span, literal }
}
ExprKind::StaticRef { alloc_id, ty, .. } => {
let const_val = ConstValue::Scalar(Scalar::from_pointer(alloc_id.into(), &tcx));
let literal = ConstantKind::Val(const_val, ty);

Constant { span, user_ty: None, literal }
}
_ => span_bug!(span, "expression is not a valid constant {:?}", kind),
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ fn construct_fn<'tcx>(
arguments,
return_ty,
return_ty_span,
span,
span_with_body,
custom_mir_attr,
);
}
Expand Down
Loading

0 comments on commit 9c0bc30

Please sign in to comment.