Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix inline const pattern unsafety checking in THIR #116482

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions compiler/rustc_interface/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,16 @@ fn analysis(tcx: TyCtxt<'_>, (): ()) -> Result<()> {
rustc_hir_analysis::check_crate(tcx)?;

sess.time("MIR_borrow_checking", || {
tcx.hir().par_body_owners(|def_id| tcx.ensure().mir_borrowck(def_id));
tcx.hir().par_body_owners(|def_id| {
// Run THIR unsafety check because it's responsible for stealing
// and deallocating THIR when enabled.
tcx.ensure().thir_check_unsafety(def_id);
tcx.ensure().mir_borrowck(def_id)
});
});

sess.time("MIR_effect_checking", || {
for def_id in tcx.hir().body_owners() {
tcx.ensure().thir_check_unsafety(def_id);
if !tcx.sess.opts.unstable_opts.thir_unsafeck {
rustc_mir_transform::check_unsafety::check_unsafety(tcx, def_id);
}
Expand Down
22 changes: 21 additions & 1 deletion compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ impl<'tcx> Pat<'tcx> {
Wild | Range(..) | Binding { subpattern: None, .. } | Constant { .. } | Error(_) => {}
AscribeUserType { subpattern, .. }
| Binding { subpattern: Some(subpattern), .. }
| Deref { subpattern } => subpattern.walk_(it),
| Deref { subpattern }
| InlineConstant { subpattern, .. } => subpattern.walk_(it),
Leaf { subpatterns } | Variant { subpatterns, .. } => {
subpatterns.iter().for_each(|field| field.pattern.walk_(it))
}
Expand Down Expand Up @@ -764,6 +765,22 @@ pub enum PatKind<'tcx> {
value: mir::Const<'tcx>,
},

/// Inline constant found while lowering a pattern.
InlineConstant {
/// [LocalDefId] of the constant, we need this so that we have a
/// reference that can be used by unsafety checking to visit nested
/// unevaluated constants.
def: LocalDefId,
/// If the inline constant is used in a range pattern, this subpattern
/// represents the range (if both ends are inline constants, there will
/// be multiple InlineConstant wrappers).
///
/// Otherwise, the actual pattern that the constant lowered to. As with
/// other constants, inline constants are matched structurally where
/// possible.
subpattern: Box<Pat<'tcx>>,
},

Range(Box<PatRange<'tcx>>),

/// Matches against a slice, checking the length and extracting elements.
Expand Down Expand Up @@ -924,6 +941,9 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
write!(f, "{subpattern}")
}
PatKind::Constant { value } => write!(f, "{value}"),
PatKind::InlineConstant { def: _, ref subpattern } => {
write!(f, "{} (from inline const)", subpattern)
}
PatKind::Range(box PatRange { lo, hi, end }) => {
write!(f, "{lo}")?;
write!(f, "{end}")?;
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_middle/src/thir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,17 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
}
}
Constant { value: _ } => {}
InlineConstant { def: _, subpattern } => visitor.visit_pat(subpattern),
Range(_) => {}
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
for subpattern in prefix.iter() {
visitor.visit_pat(&subpattern);
visitor.visit_pat(subpattern);
}
if let Some(pat) = slice {
visitor.visit_pat(&pat);
visitor.visit_pat(pat);
}
for subpattern in suffix.iter() {
visitor.visit_pat(&subpattern);
visitor.visit_pat(subpattern);
}
}
Or { pats } => {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_mir_build/src/build/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.visit_primary_bindings(subpattern, subpattern_user_ty, f)
}

PatKind::InlineConstant { ref subpattern, .. } => {
self.visit_primary_bindings(subpattern, pattern_user_ty.clone(), f)
}

PatKind::Leaf { ref subpatterns } => {
for subpattern in subpatterns {
let subpattern_user_ty = pattern_user_ty.clone().leaf(subpattern.field);
Expand Down
10 changes: 8 additions & 2 deletions compiler/rustc_mir_build/src/build/matches/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Err(match_pair)
}

PatKind::InlineConstant { subpattern: ref pattern, def: _ } => {
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));

Ok(())
}

PatKind::Range(box PatRange { lo, hi, end }) => {
let (range, bias) = match *lo.ty().kind() {
ty::Char => {
Expand All @@ -229,8 +235,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// correct the comparison. This is achieved by XORing with a bias (see
// pattern/_match.rs for another pertinent example of this pattern).
//
// Also, for performance, it's important to only do the second `try_to_bits` if
// necessary.
// Also, for performance, it's important to only do the second
// `try_to_bits` if necessary.
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
if lo <= min {
let hi = hi.try_to_bits(sz).unwrap() ^ bias;
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_build/src/build/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
PatKind::Or { .. } => bug!("or-patterns should have already been handled"),

PatKind::AscribeUserType { .. }
| PatKind::InlineConstant { .. }
| PatKind::Array { .. }
| PatKind::Wild
| PatKind::Binding { .. }
Expand Down Expand Up @@ -111,6 +112,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
| PatKind::Or { .. }
| PatKind::Binding { .. }
| PatKind::AscribeUserType { .. }
| PatKind::InlineConstant { .. }
| PatKind::Leaf { .. }
| PatKind::Deref { .. }
| PatKind::Error(_) => {
Expand Down
26 changes: 15 additions & 11 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ pub(crate) fn closure_saved_names_of_captured_variables<'tcx>(
}

/// Construct the MIR for a given `DefId`.
fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
// Ensure unsafeck and abstract const building is ran before we steal the THIR.
tcx.ensure_with_value()
.thir_check_unsafety(tcx.typeck_root_def_id(def.to_def_id()).expect_local());
fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
tcx.ensure_with_value().thir_abstract_const(def);
if let Err(e) = tcx.check_match(def) {
return construct_error(tcx, def, e);
Expand All @@ -65,20 +62,27 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
let body = match tcx.thir_body(def) {
Err(error_reported) => construct_error(tcx, def, error_reported),
Ok((thir, expr)) => {
// We ran all queries that depended on THIR at the beginning
// of `mir_build`, so now we can steal it
let thir = thir.steal();
let build_mir = |thir: &Thir<'tcx>| match thir.body_type {
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, thir, expr, fn_sig),
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
};

tcx.ensure().check_match(def);
// this must run before MIR dump, because
// "not all control paths return a value" is reported here.
//
// maybe move the check to a MIR pass?
tcx.ensure().check_liveness(def);

match thir.body_type {
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, &thir, expr, fn_sig),
thir::BodyTy::Const(ty) => construct_const(tcx, def, &thir, expr, ty),
if tcx.sess.opts.unstable_opts.thir_unsafeck {
// Don't steal here if THIR unsafeck is being used. Instead
// steal in unsafeck. This is so that pattern inline constants
// can be evaluated as part of building the THIR of the parent
// function without a cycle.
build_mir(&thir.borrow())
} else {
// We ran all queries that depended on THIR at the beginning
// of `mir_build`, so now we can steal it
build_mir(&thir.steal())
}
}
};
Expand Down
12 changes: 10 additions & 2 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
fn visit_inner_body(&mut self, def: LocalDefId) {
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
let inner_thir = &inner_thir.borrow();
// Runs all other queries that depend on THIR.
self.tcx.ensure_with_value().mir_built(def);
let inner_thir = &inner_thir.steal();
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
inner_visitor.visit_expr(&inner_thir[expr]);
Expand Down Expand Up @@ -224,6 +226,7 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
PatKind::Wild |
// these just wrap other patterns
PatKind::Or { .. } |
PatKind::InlineConstant { .. } |
PatKind::AscribeUserType { .. } |
PatKind::Error(_) => {}
}
Expand Down Expand Up @@ -277,6 +280,9 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt;
}
PatKind::InlineConstant { def, .. } => {
self.visit_inner_body(*def);
}
_ => {
visit::walk_pat(self, pat);
}
Expand Down Expand Up @@ -788,7 +794,9 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
}

let Ok((thir, expr)) = tcx.thir_body(def) else { return };
let thir = &thir.borrow();
// Runs all other queries that depend on THIR.
tcx.ensure_with_value().mir_built(def);
let thir = &thir.steal();
// If `thir` is empty, a type error occurred, skip this body.
if thir.exprs.is_empty() {
return;
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1356,7 +1356,8 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
let ctor;
let fields;
match &pat.kind {
PatKind::AscribeUserType { subpattern, .. } => return mkpat(subpattern),
PatKind::AscribeUserType { subpattern, .. }
| PatKind::InlineConstant { subpattern, .. } => return mkpat(subpattern),
PatKind::Binding { subpattern: Some(subpat), .. } => return mkpat(subpat),
PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
ctor = Wildcard;
Expand Down
46 changes: 30 additions & 16 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use rustc_middle::ty::{
self, AdtDef, CanonicalUserTypeAnnotation, GenericArg, GenericArgsRef, Region, Ty, TyCtxt,
TypeVisitableExt, UserType,
};
use rustc_span::def_id::LocalDefId;
use rustc_span::{ErrorGuaranteed, Span, Symbol};
use rustc_target::abi::{FieldIdx, Integer};

Expand Down Expand Up @@ -88,15 +89,21 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
fn lower_pattern_range_endpoint(
&mut self,
expr: Option<&'tcx hir::Expr<'tcx>>,
) -> Result<(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>), ErrorGuaranteed> {
) -> Result<
(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>, Option<LocalDefId>),
ErrorGuaranteed,
> {
match expr {
None => Ok((None, None)),
None => Ok((None, None, None)),
Some(expr) => {
let (kind, ascr) = match self.lower_lit(expr) {
let (kind, ascr, inline_const) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, def } => {
(subpattern.kind, None, Some(def))
}
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription))
(kind, Some(ascription), None)
}
kind => (kind, None),
kind => (kind, None, None),
};
let value = if let PatKind::Constant { value } = kind {
value
Expand All @@ -106,7 +113,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
);
return Err(self.tcx.sess.delay_span_bug(expr.span, msg));
};
Ok((Some(value), ascr))
Ok((Some(value), ascr, inline_const))
}
}
}
Expand Down Expand Up @@ -177,8 +184,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
return Err(self.tcx.sess.delay_span_bug(span, msg));
}

let (lo, lo_ascr) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr) = self.lower_pattern_range_endpoint(hi_expr)?;
let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?;

let lo = lo.unwrap_or_else(|| {
// Unwrap is ok because the type is known to be numeric.
Expand Down Expand Up @@ -237,6 +244,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
};
}
}
for inline_const in [lo_inline, hi_inline] {
if let Some(def) = inline_const {
kind =
PatKind::InlineConstant { def, subpattern: Box::new(Pat { span, ty, kind }) };
Copy link
Contributor

@b-naber b-naber Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's that intuitive to hide the range pattern behind this nested PatKind::InlineConstant construction. Having an explicit range pattern makes this clearer imo. To me this seems to add complexity, but I don't know off the top of my head how these range patterns with inline constants are currently handled in later stages. Does this make things easier in later stages in a way that justifies introducing the complexity here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes things simpler for the consuming side: unsafety checking only cares that there are some inline constants around somewhere, pattern checking only cares that there's a range pattern whose extremities can be evaluated to bits. Carrying the unevaluated constant around felt invisible and brittle, my gut said danger.

But I am biased. I have an open PR (#116692) that reworks PatRange a lot, and I'm hoping I can eventually evaluate the consts eagerly and store bits in PatRange.

Copy link
Contributor

@b-naber b-naber Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Carrying the unevaluated constant around felt invisible and brittle, my gut said danger.

I don't understand this. Looking at const_to_pat, we always error if we find an unevaluatable constant. When we lower inline consts we currently always go through const_to_pat in the non-error case. Why would we carry unevaluated constants around?

Can't we introduce a new enum that is used in PatKind::Constant and includes variants for InlineConst(mir::Const, LocalDefId) and mir::Const, instead of modelling inline constants at the pattern level? The visitor in THIR unsafety checking could still detect inline constants, pattern checking should be able to work with this afaict, and we could continue to encode pattern ranges with inline constants as PatKind::Range directly, which seems less hacky than the nested PatKind::InlineConstant imo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I said above:

> Inline constants, like other constants, can be lowered to actual patterns in THIR. We could keep them as constants here and lower to patterns in both MIR building and exhaustiveness checking, but I don't think that's any better.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> Why would we carry unevaluated constants around?

Unsafety checking needs to look at the body of the inline constant, so if we want to use Const to carry it around, we must not evaluate it. Pretty sure an evaluated const doesn't know about the body it came from. This is what was done in ranges before I asked for this special InlineConstant behavior.

> Can't we introduce a new enum, that is used in PatKind::Constant

We can't if we have an inline constant that evaluates to something that const_to_pat will deconstruct, e.g.:

// Example: an inline constant whose value gets deconstructed into a structural
// pattern rather than staying a single opaque constant.
fn main() {
    match Some(0) {
        // Inline const pattern. Its value, `Some(0)`, has to be lowered to a
        // `Pat` representing `Some(0)` so that exhaustiveness/unreachable
        // checking can reason about it.
        const { Some(0) } => println!("got 0!"),
        // Same shape as the inline const arm above, so this arm is expected to
        // be unreachable.
        Some(0) => unreachable!(),
        _ => println!("something else"),
    }
}

Here the constant needs to be deconstructed into a Pat that represents Some(0) so that we can do proper exhaustiveness/unreachable checking with it. The resulting Pat might not have any PatKind::Constant left at the end. Hence why we attach the inline constant in the middle of the pattern.

}
}
Ok(kind)
}

Expand Down Expand Up @@ -599,11 +612,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// const eval path below.
// FIXME: investigate the performance impact of removing this.
let lit_input = match expr.kind {
hir::ExprKind::Lit(ref lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, ref expr) => match expr.kind {
hir::ExprKind::Lit(ref lit) => {
Some(LitToConstInput { lit: &lit.node, ty, neg: true })
}
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, expr) => match expr.kind {
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: true }),
_ => None,
},
_ => None,
Expand Down Expand Up @@ -633,13 +644,13 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
if let Ok(Some(valtree)) =
self.tcx.const_eval_resolve_for_typeck(self.param_env, ct, Some(span))
{
self.const_to_pat(
let subpattern = self.const_to_pat(
Const::Ty(ty::Const::new_value(self.tcx, valtree, ty)),
id,
span,
None,
)
.kind
);
PatKind::InlineConstant { subpattern, def: def_id }
} else {
// If that fails, convert it to an opaque constant pattern.
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
Expand Down Expand Up @@ -822,6 +833,9 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
PatKind::Deref { subpattern: subpattern.fold_with(folder) }
}
PatKind::Constant { value } => PatKind::Constant { value },
PatKind::InlineConstant { def, subpattern: ref pattern } => {
PatKind::InlineConstant { def, subpattern: pattern.fold_with(folder) }
}
PatKind::Range(ref range) => PatKind::Range(range.clone()),
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {
prefix: prefix.fold_with(folder),
Expand Down
9 changes: 8 additions & 1 deletion compiler/rustc_mir_build/src/thir/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
}
PatKind::Deref { subpattern } => {
print_indented!(self, "Deref { ", depth_lvl + 1);
print_indented!(self, "subpattern: ", depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
Expand All @@ -701,6 +701,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::InlineConstant { def, subpattern } => {
print_indented!(self, "InlineConstant {", depth_lvl + 1);
print_indented!(self, format!("def: {:?}", def), depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::Range(pat_range) => {
print_indented!(self, format!("Range ( {:?} )", pat_range), depth_lvl + 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LL | S::f();
= note: consult the function's documentation for information on how to avoid undefined behavior

error[E0133]: call to unsafe function is unsafe and requires unsafe function or block
--> $DIR/async-unsafe-fn-call-in-safe.rs:24:5
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
|
LL | f();
| ^^^ call to unsafe function
Expand Down
8 changes: 6 additions & 2 deletions tests/ui/async-await/async-unsafe-fn-call-in-safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ async fn g() {
}

fn main() {
S::f(); //[mir]~ ERROR call to unsafe function is unsafe
f(); //[mir]~ ERROR call to unsafe function is unsafe
S::f();
//[mir]~^ ERROR call to unsafe function is unsafe
//[thir]~^^ ERROR call to unsafe function `S::f` is unsafe
f();
//[mir]~^ ERROR call to unsafe function is unsafe
//[thir]~^^ ERROR call to unsafe function `f` is unsafe
}
Loading