Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use placeholders to prevent using inferred RPITIT types to imply their own well-formedness #116072

Merged
merged 2 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 88 additions & 8 deletions compiler/rustc_hir_analysis/src/check/compare_impl_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKi
use rustc_infer::infer::{self, InferCtxt, TyCtxtInferExt};
use rustc_infer::traits::util;
use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::fold::BottomUpFolder;
use rustc_middle::ty::util::ExplicitSelf;
use rustc_middle::ty::{
self, GenericArgs, Ty, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
Expand Down Expand Up @@ -661,8 +662,6 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
let trait_m = tcx.opt_associated_item(impl_m.trait_item_def_id.unwrap()).unwrap();
let impl_trait_ref =
tcx.impl_trait_ref(impl_m.impl_container(tcx).unwrap()).unwrap().instantiate_identity();
let param_env = tcx.param_env(impl_m_def_id);

// First, check a few of the same things as `compare_impl_method`,
// just so we don't ICE during substitution later.
check_method_is_structurally_compatible(tcx, impl_m, trait_m, impl_trait_ref, true)?;
Expand All @@ -688,13 +687,26 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
let trait_to_placeholder_args =
impl_to_placeholder_args.rebase_onto(tcx, impl_m.container_id(tcx), trait_to_impl_args);

let hybrid_preds = tcx
.predicates_of(impl_m.container_id(tcx))
.instantiate_identity(tcx)
.into_iter()
.chain(tcx.predicates_of(trait_m.def_id).instantiate_own(tcx, trait_to_placeholder_args))
.map(|(clause, _)| clause);
let param_env = ty::ParamEnv::new(tcx.mk_clauses_from_iter(hybrid_preds), Reveal::UserFacing);
let param_env = traits::normalize_param_env_or_error(
tcx,
param_env,
ObligationCause::misc(tcx.def_span(impl_m_def_id), impl_m_def_id),
);

let infcx = &tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(infcx);

// Normalize the impl signature with fresh variables for lifetime inference.
let norm_cause = ObligationCause::misc(return_span, impl_m_def_id);
let misc_cause = ObligationCause::misc(return_span, impl_m_def_id);
let impl_sig = ocx.normalize(
&norm_cause,
&misc_cause,
param_env,
tcx.liberate_late_bound_regions(
impl_m.def_id,
Expand Down Expand Up @@ -725,12 +737,68 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
);
}

let trait_sig = ocx.normalize(&norm_cause, param_env, unnormalized_trait_sig);
let trait_sig = ocx.normalize(&misc_cause, param_env, unnormalized_trait_sig);
trait_sig.error_reported()?;
let trait_return_ty = trait_sig.output();

// RPITITs are allowed to use the implied predicates of the method that
// defines them. This is because we want code like:
// ```
// trait Foo {
// fn test<'a, T>(_: &'a T) -> impl Sized;
// }
// impl Foo for () {
// fn test<'a, T>(x: &'a T) -> &'a T { x }
// }
// ```
// .. to compile. However, since we use both the normalized and unnormalized
// inputs and outputs from the substituted trait signature, we will end up
// seeing the hidden type of an RPIT in the signature itself. Naively, this
// means that we will use the hidden type to imply the hidden type's own
// well-formedness.
//
// To avoid this, we replace the infer vars used for hidden type inference
// with placeholders, which imply nothing about outlives bounds, and then
// prove below that the hidden types are well formed.
let universe = infcx.create_next_universe();
let mut idx = 0;
let mapping: FxHashMap<_, _> = collector
.types
.iter()
.map(|(_, &(ty, _))| {
assert!(
infcx.resolve_vars_if_possible(ty) == ty && ty.is_ty_var(),
"{ty:?} should not have been constrained via normalization",
ty = infcx.resolve_vars_if_possible(ty)
);
idx += 1;
(
ty,
Ty::new_placeholder(
tcx,
ty::Placeholder {
universe,
bound: ty::BoundTy {
var: ty::BoundVar::from_usize(idx),
kind: ty::BoundTyKind::Anon,
},
},
),
)
})
.collect();
let mut type_mapper = BottomUpFolder {
tcx,
ty_op: |ty| *mapping.get(&ty).unwrap_or(&ty),
lt_op: |lt| lt,
ct_op: |ct| ct,
};
let wf_tys = FxIndexSet::from_iter(
unnormalized_trait_sig.inputs_and_output.iter().chain(trait_sig.inputs_and_output.iter()),
unnormalized_trait_sig
.inputs_and_output
.iter()
.chain(trait_sig.inputs_and_output.iter())
.map(|ty| ty.fold_with(&mut type_mapper)),
);

match ocx.eq(&cause, param_env, trait_return_ty, impl_return_ty) {
Expand Down Expand Up @@ -787,6 +855,20 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
}
}

// FIXME: This has the same issue as #108544, but since this isn't breaking
// existing code, I'm not particularly inclined to do the same hack as above
// where we process wf obligations manually. This can be fixed in a forward-
// compatible way later.
let collected_types = collector.types;
for (_, &(ty, _)) in &collected_types {
ocx.register_obligation(traits::Obligation::new(
tcx,
misc_cause.clone(),
param_env,
ty::ClauseKind::WellFormed(ty.into()),
));
}

// Check that all obligations are satisfied by the implementation's
// RPITs.
let errors = ocx.select_all_or_error();
Expand All @@ -795,8 +877,6 @@ pub(super) fn collect_return_position_impl_trait_in_trait_tys<'tcx>(
return Err(reported);
}

let collected_types = collector.types;

// Finally, resolve all regions. This catches wily misuses of
// lifetime parameters.
let outlives_env = OutlivesEnvironment::with_bounds(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#![feature(return_position_impl_trait_in_trait)]

trait Extend {
fn extend<'a: 'a>(_: &'a str) -> (impl Sized + 'a, &'static str);
}

impl Extend for () {
fn extend<'a: 'a>(s: &'a str) -> (Option<&'static &'a ()>, &'static str)
//~^ ERROR in type `&'static &'a ()`, reference has a longer lifetime than the data it references
where
'a: 'static,
{
(None, s)
}
}

// This indirection is not necessary for reproduction,
// but it makes this test future-proof against #114936.
fn extend<T: Extend>(s: &str) -> &'static str {
<T as Extend>::extend(s).1
}

fn main() {
let use_after_free = extend::<()>(&String::from("temporary"));
println!("{}", use_after_free);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error[E0491]: in type `&'static &'a ()`, reference has a longer lifetime than the data it references
--> $DIR/rpitit-hidden-types-self-implied-wf-via-param.rs:8:38
|
LL | fn extend<'a: 'a>(s: &'a str) -> (Option<&'static &'a ()>, &'static str)
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: the pointer is valid for the static lifetime
note: but the referenced data is only valid for the lifetime `'a` as defined here
--> $DIR/rpitit-hidden-types-self-implied-wf-via-param.rs:8:15
|
LL | fn extend<'a: 'a>(s: &'a str) -> (Option<&'static &'a ()>, &'static str)
| ^^

error: aborting due to previous error

For more information about this error, try `rustc --explain E0491`.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#![feature(return_position_impl_trait_in_trait)]

trait Extend {
fn extend(_: &str) -> (impl Sized + '_, &'static str);
}

impl Extend for () {
fn extend(s: &str) -> (Option<&'static &'_ ()>, &'static str) {
//~^ ERROR in type `&'static &()`, reference has a longer lifetime than the data it references
(None, s)
}
}

// This indirection is not necessary for reproduction,
// but it makes this test future-proof against #114936.
fn extend<T: Extend>(s: &str) -> &'static str {
<T as Extend>::extend(s).1
}

fn main() {
let use_after_free = extend::<()>(&String::from("temporary"));
println!("{}", use_after_free);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error[E0491]: in type `&'static &()`, reference has a longer lifetime than the data it references
--> $DIR/rpitit-hidden-types-self-implied-wf.rs:8:27
|
LL | fn extend(s: &str) -> (Option<&'static &'_ ()>, &'static str) {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
= note: the pointer is valid for the static lifetime
note: but the referenced data is only valid for the anonymous lifetime defined here
--> $DIR/rpitit-hidden-types-self-implied-wf.rs:8:18
|
LL | fn extend(s: &str) -> (Option<&'static &'_ ()>, &'static str) {
| ^^^^

error: aborting due to previous error

For more information about this error, try `rustc --explain E0491`.