Skip to content

Commit

Permalink
Fix inline const pattern unsafety checking in THIR
Browse files Browse the repository at this point in the history
THIR unsafety checking was getting a cycle of
function unsafety checking
-> building THIR for the function
-> evaluating pattern inline constants in the function
-> building MIR for the inline constant
-> checking unsafety of functions (so that THIR can be stolen)
This is fixed by not stealing THIR when generating MIR but instead when
unsafety checking.
This leaves an issue with pattern inline constants not being unsafety
checked because they are evaluated away when generating THIR.
To fix that we now represent inline constants in THIR patterns and
visit them in THIR unsafety checking.
  • Loading branch information
matthewjasper committed Oct 12, 2023
1 parent 3ff244b commit 98b4c1e
Show file tree
Hide file tree
Showing 24 changed files with 239 additions and 50 deletions.
8 changes: 6 additions & 2 deletions compiler/rustc_interface/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,12 +775,16 @@ fn analysis(tcx: TyCtxt<'_>, (): ()) -> Result<()> {
rustc_hir_analysis::check_crate(tcx)?;

sess.time("MIR_borrow_checking", || {
tcx.hir().par_body_owners(|def_id| tcx.ensure().mir_borrowck(def_id));
tcx.hir().par_body_owners(|def_id| {
// Run THIR unsafety check because it's responsible for stealing
// and deallocating THIR when enabled.
tcx.ensure().thir_check_unsafety(def_id);
tcx.ensure().mir_borrowck(def_id)
});
});

sess.time("MIR_effect_checking", || {
for def_id in tcx.hir().body_owners() {
tcx.ensure().thir_check_unsafety(def_id);
if !tcx.sess.opts.unstable_opts.thir_unsafeck {
rustc_mir_transform::check_unsafety::check_unsafety(tcx, def_id);
}
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,8 @@ impl<'tcx> Pat<'tcx> {
Wild | Range(..) | Binding { subpattern: None, .. } | Constant { .. } => {}
AscribeUserType { subpattern, .. }
| Binding { subpattern: Some(subpattern), .. }
| Deref { subpattern } => subpattern.walk_(it),
| Deref { subpattern }
| InlineConstant { subpattern, .. } => subpattern.walk_(it),
Leaf { subpatterns } | Variant { subpatterns, .. } => {
subpatterns.iter().for_each(|field| field.pattern.walk_(it))
}
Expand Down Expand Up @@ -748,6 +749,11 @@ pub enum PatKind<'tcx> {
value: mir::Const<'tcx>,
},

InlineConstant {
value: mir::UnevaluatedConst<'tcx>,
subpattern: Box<Pat<'tcx>>,
},

Range(Box<PatRange<'tcx>>),

/// Matches against a slice, checking the length and extracting elements.
Expand Down Expand Up @@ -904,6 +910,9 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
write!(f, "{subpattern}")
}
PatKind::Constant { value } => write!(f, "{value}"),
PatKind::InlineConstant { value: _, ref subpattern } => {
write!(f, "{} (from inline const)", subpattern)
}
PatKind::Range(box PatRange { lo, hi, end }) => {
write!(f, "{lo}")?;
write!(f, "{end}")?;
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_middle/src/thir/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,17 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
}
}
Constant { value: _ } => {}
InlineConstant { value: _, subpattern } => visitor.visit_pat(subpattern),
Range(_) => {}
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
for subpattern in prefix.iter() {
visitor.visit_pat(&subpattern);
visitor.visit_pat(subpattern);
}
if let Some(pat) = slice {
visitor.visit_pat(&pat);
visitor.visit_pat(pat);
}
for subpattern in suffix.iter() {
visitor.visit_pat(&subpattern);
visitor.visit_pat(subpattern);
}
}
Or { pats } => {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_mir_build/src/build/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.visit_primary_bindings(subpattern, subpattern_user_ty, f)
}

PatKind::InlineConstant { ref subpattern, .. } => {
self.visit_primary_bindings(subpattern, pattern_user_ty.clone(), f)
}

PatKind::Leaf { ref subpatterns } => {
for subpattern in subpatterns {
let subpattern_user_ty = pattern_user_ty.clone().leaf(subpattern.field);
Expand Down
24 changes: 20 additions & 4 deletions compiler/rustc_mir_build/src/build/matches/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
Err(match_pair)
}

PatKind::InlineConstant { subpattern: ref pattern, value: _ } => {
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));

Ok(())
}

PatKind::Range(box PatRange { lo, hi, end }) => {
let (range, bias) = match *lo.ty().kind() {
ty::Char => {
Expand All @@ -229,11 +235,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// correct the comparison. This is achieved by XORing with a bias (see
// pattern/_match.rs for another pertinent example of this pattern).
//
// Also, for performance, it's important to only do the second `try_to_bits` if
// necessary.
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
// Also, for performance, it's important to only do the second
// `try_eval_scalar_int` if necessary.
let lo = lo
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
if lo <= min {
let hi = hi.try_to_bits(sz).unwrap() ^ bias;
let hi = hi
.try_eval_scalar_int(self.tcx, self.param_env)
.unwrap()
.to_bits(sz)
.unwrap()
^ bias;
if hi > max || hi == max && end == RangeEnd::Included {
// Irrefutable pattern match.
return Ok(());
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_build/src/build/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
PatKind::Or { .. } => bug!("or-patterns should have already been handled"),

PatKind::AscribeUserType { .. }
| PatKind::InlineConstant { .. }
| PatKind::Array { .. }
| PatKind::Wild
| PatKind::Binding { .. }
Expand Down Expand Up @@ -110,6 +111,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
| PatKind::Or { .. }
| PatKind::Binding { .. }
| PatKind::AscribeUserType { .. }
| PatKind::InlineConstant { .. }
| PatKind::Leaf { .. }
| PatKind::Deref { .. } => {
// don't know how to add these patterns to a switch
Expand Down
25 changes: 15 additions & 10 deletions compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ pub(crate) fn closure_saved_names_of_captured_variables<'tcx>(
}

/// Construct the MIR for a given `DefId`.
fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
// Ensure unsafeck and abstract const building is ran before we steal the THIR.
tcx.ensure_with_value()
.thir_check_unsafety(tcx.typeck_root_def_id(def.to_def_id()).expect_local());
fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
tcx.ensure_with_value().thir_abstract_const(def);
if let Err(e) = tcx.check_match(def) {
return construct_error(tcx, def, e);
Expand All @@ -65,9 +62,10 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
let body = match tcx.thir_body(def) {
Err(error_reported) => construct_error(tcx, def, error_reported),
Ok((thir, expr)) => {
// We ran all queries that depended on THIR at the beginning
// of `mir_build`, so now we can steal it
let thir = thir.steal();
let build_mir = |thir: &Thir<'tcx>| match thir.body_type {
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, thir, expr, fn_sig),
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
};

tcx.ensure().check_match(def);
// this must run before MIR dump, because
Expand All @@ -76,9 +74,16 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
// maybe move the check to a MIR pass?
tcx.ensure().check_liveness(def);

match thir.body_type {
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, &thir, expr, fn_sig),
thir::BodyTy::Const(ty) => construct_const(tcx, def, &thir, expr, ty),
if tcx.sess.opts.unstable_opts.thir_unsafeck {
// Don't steal here if THIR unsafeck is being used. Instead
// steal in unsafeck. This is so that pattern inline constants
// can be evaluated as part of building the THIR of the parent
// function without a cycle.
build_mir(&thir.borrow())
} else {
// We ran all queries that depended on THIR at the beginning
// of `mir_build`, so now we can steal it
build_mir(&thir.steal())
}
}
};
Expand Down
27 changes: 24 additions & 3 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::errors::*;
use rustc_middle::thir::visit::{self, Visitor};

use rustc_hir as hir;
use rustc_middle::mir::BorrowKind;
use rustc_middle::mir::{BorrowKind, Const};
use rustc_middle::thir::*;
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
Expand Down Expand Up @@ -124,7 +124,8 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
fn visit_inner_body(&mut self, def: LocalDefId) {
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
let inner_thir = &inner_thir.borrow();
let _ = self.tcx.ensure_with_value().mir_built(def);
let inner_thir = &inner_thir.steal();
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
inner_visitor.visit_expr(&inner_thir[expr]);
Expand Down Expand Up @@ -224,6 +225,7 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
PatKind::Wild |
// these just wrap other patterns
PatKind::Or { .. } |
PatKind::InlineConstant { .. } |
PatKind::AscribeUserType { .. } => {}
}
};
Expand Down Expand Up @@ -276,6 +278,24 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt;
}
PatKind::Range(range) => {
if let Const::Unevaluated(c, _) = range.lo {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
if let Const::Unevaluated(c, _) = range.hi {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
}
PatKind::InlineConstant { value, .. } => {
let def_id = value.def.expect_local();
self.visit_inner_body(def_id);
}
_ => {
visit::walk_pat(self, pat);
}
Expand Down Expand Up @@ -784,7 +804,8 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
}

let Ok((thir, expr)) = tcx.thir_body(def) else { return };
let thir = &thir.borrow();
let _ = tcx.ensure_with_value().mir_built(def);
let thir = &thir.steal();
// If `thir` is empty, a type error occurred, skip this body.
if thir.exprs.is_empty() {
return;
Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,8 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
let ctor;
let fields;
match &pat.kind {
PatKind::AscribeUserType { subpattern, .. } => return mkpat(subpattern),
PatKind::AscribeUserType { subpattern, .. }
| PatKind::InlineConstant { subpattern, .. } => return mkpat(subpattern),
PatKind::Binding { subpattern: Some(subpat), .. } => return mkpat(subpat),
PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
ctor = Wildcard;
Expand Down
13 changes: 10 additions & 3 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
None => Ok((None, None)),
Some(expr) => {
let (kind, ascr) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, value } => (
PatKind::Constant { value: Const::Unevaluated(value, subpattern.ty) },
None,
),
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription))
}
Expand Down Expand Up @@ -636,13 +640,13 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
if let Ok(Some(valtree)) =
self.tcx.const_eval_resolve_for_typeck(self.param_env, ct, Some(span))
{
self.const_to_pat(
let subpattern = self.const_to_pat(
Const::Ty(ty::Const::new_value(self.tcx, valtree, ty)),
id,
span,
None,
)
.kind
);
PatKind::InlineConstant { subpattern, value: uneval }
} else {
// If that fails, convert it to an opaque constant pattern.
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
Expand Down Expand Up @@ -824,6 +828,9 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
PatKind::Deref { subpattern: subpattern.fold_with(folder) }
}
PatKind::Constant { value } => PatKind::Constant { value },
PatKind::InlineConstant { value, subpattern: ref pattern } => {
PatKind::InlineConstant { value, subpattern: pattern.fold_with(folder) }
}
PatKind::Range(ref range) => PatKind::Range(range.clone()),
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {
prefix: prefix.fold_with(folder),
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_mir_build/src/thir/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::InlineConstant { value, subpattern } => {
print_indented!(self, "InlineConstant {", depth_lvl + 1);
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "subpattern: ", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
PatKind::Range(pat_range) => {
print_indented!(self, format!("Range ( {:?} )", pat_range), depth_lvl + 1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ LL | S::f();
= note: consult the function's documentation for information on how to avoid undefined behavior

error[E0133]: call to unsafe function is unsafe and requires unsafe function or block
--> $DIR/async-unsafe-fn-call-in-safe.rs:24:5
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
|
LL | f();
| ^^^ call to unsafe function
Expand Down
8 changes: 6 additions & 2 deletions tests/ui/async-await/async-unsafe-fn-call-in-safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ async fn g() {
}

fn main() {
S::f(); //[mir]~ ERROR call to unsafe function is unsafe
f(); //[mir]~ ERROR call to unsafe function is unsafe
S::f();
//[mir]~^ ERROR call to unsafe function is unsafe
//[thir]~^^ ERROR call to unsafe function `S::f` is unsafe
f();
//[mir]~^ ERROR call to unsafe function is unsafe
//[thir]~^^ ERROR call to unsafe function `f` is unsafe
}
18 changes: 17 additions & 1 deletion tests/ui/async-await/async-unsafe-fn-call-in-safe.thir.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ LL | f();
|
= note: consult the function's documentation for information on how to avoid undefined behavior

error: aborting due to 2 previous errors
error[E0133]: call to unsafe function `S::f` is unsafe and requires unsafe function or block
--> $DIR/async-unsafe-fn-call-in-safe.rs:23:5
|
LL | S::f();
| ^^^^^^ call to unsafe function
|
= note: consult the function's documentation for information on how to avoid undefined behavior

error[E0133]: call to unsafe function `f` is unsafe and requires unsafe function or block
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
|
LL | f();
| ^^^ call to unsafe function
|
= note: consult the function's documentation for information on how to avoid undefined behavior

error: aborting due to 4 previous errors

For more information about this error, try `rustc --explain E0133`.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ fn main() {
//[thir]~^^ call to unsafe function `foo` is unsafe and requires unsafe function or block
foo();
//[mir]~^ ERROR call to unsafe function is unsafe and requires unsafe function or block
//[thir]~^^ ERROR call to unsafe function `foo` is unsafe and requires unsafe function or block
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
error[E0133]: call to unsafe function `foo` is unsafe and requires unsafe function or block
--> $DIR/const-extern-fn-requires-unsafe.rs:12:5
|
LL | foo();
| ^^^^^ call to unsafe function
|
= note: consult the function's documentation for information on how to avoid undefined behavior

error[E0133]: call to unsafe function `foo` is unsafe and requires unsafe function or block
--> $DIR/const-extern-fn-requires-unsafe.rs:9:17
|
Expand All @@ -6,6 +14,6 @@ LL | let a: [u8; foo()];
|
= note: consult the function's documentation for information on how to avoid undefined behavior

error: aborting due to previous error
error: aborting due to 2 previous errors

For more information about this error, try `rustc --explain E0133`.
Loading

0 comments on commit 98b4c1e

Please sign in to comment.