Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to evaluate in try unify and postpone resolution of constants that contain inference variables #95179

Merged
merged 7 commits into from
Mar 25, 2022
18 changes: 10 additions & 8 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ mod sub;
pub mod type_variable;
mod undo_log;

use crate::infer::canonical::OriginalQueryValues;
pub use rustc_middle::infer::unify_key;

#[must_use]
Expand Down Expand Up @@ -695,14 +694,19 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
) -> bool {
// Reject any attempt to unify two unevaluated constants that contain inference
// variables, since inference variables in queries lead to ICEs.
if a.substs.has_infer_types_or_consts() || b.substs.has_infer_types_or_consts() {
debug!("a or b contain infer vars in its substs -> cannot unify");
if a.substs.has_infer_types_or_consts()
|| b.substs.has_infer_types_or_consts()
|| param_env.has_infer_types_or_consts()
{
debug!("a or b or param_env contain infer vars in its substs -> cannot unify");
return false;
}

let canonical = self.canonicalize_query((a, b), &mut OriginalQueryValues::default());
let erased_args = self.tcx.erase_regions((a, b));
let erased_param_env = self.tcx.erase_regions(param_env);
b-naber marked this conversation as resolved.
Show resolved Hide resolved
debug!("after erase_regions args: {:?}, param_env: {:?}", erased_args, param_env);

self.tcx.try_unify_abstract_consts(param_env.and(canonical.value))
self.tcx.try_unify_abstract_consts(erased_param_env.and(erased_args))
}

pub fn is_in_snapshot(&self) -> bool {
Expand Down Expand Up @@ -1619,9 +1623,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
// variables
if substs.has_infer_types_or_consts() {
debug!("substs have infer types or consts: {:?}", substs);
if substs.has_infer_types_or_consts() {
return Err(ErrorHandled::TooGeneric);
}
return Err(ErrorHandled::TooGeneric);
}

let param_env_erased = self.tcx.erase_regions(param_env);
Expand Down
33 changes: 12 additions & 21 deletions compiler/rustc_trait_selection/src/traits/const_evaluatable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ fn satisfied_from_param_env<'tcx>(
match pred.kind().skip_binder() {
ty::PredicateKind::ConstEvaluatable(uv) => {
if let Some(b_ct) = AbstractConst::new(tcx, uv)? {
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);

// Try to unify with each subtree in the AbstractConst to allow for
// `N + 1` being const evaluatable even if theres only a `ConstEvaluatable`
// predicate for `(N + 1) * 2`
let result = walk_abstract_const(tcx, b_ct, |b_ct| {
match try_unify(tcx, ct, b_ct, param_env) {
match const_unify_ctxt.try_unify(ct, b_ct) {
true => ControlFlow::BREAK,
false => ControlFlow::CONTINUE,
}
Expand Down Expand Up @@ -569,18 +571,6 @@ pub(super) fn thir_abstract_const<'tcx>(
}
}

/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(tcx), level = "debug")]
pub(super) fn try_unify<'tcx>(
tcx: TyCtxt<'tcx>,
a: AbstractConst<'tcx>,
b: AbstractConst<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);
const_unify_ctxt.try_unify_inner(a, b)
}

pub(super) fn try_unify_abstract_consts<'tcx>(
tcx: TyCtxt<'tcx>,
(a, b): (ty::Unevaluated<'tcx, ()>, ty::Unevaluated<'tcx, ()>),
Expand All @@ -589,7 +579,8 @@ pub(super) fn try_unify_abstract_consts<'tcx>(
(|| {
if let Some(a) = AbstractConst::new(tcx, a)? {
if let Some(b) = AbstractConst::new(tcx, b)? {
return Ok(try_unify(tcx, a, b, param_env));
let const_unify_ctxt = ConstUnifyCtxt::new(tcx, param_env);
return Ok(const_unify_ctxt.try_unify(a, b));
}
}

Expand Down Expand Up @@ -666,7 +657,7 @@ impl<'tcx> ConstUnifyCtxt<'tcx> {

/// Tries to unify two abstract constants using structural equality.
#[instrument(skip(self), level = "debug")]
fn try_unify_inner(&self, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool {
fn try_unify(&self, a: AbstractConst<'tcx>, b: AbstractConst<'tcx>) -> bool {
let a = if let Some(a) = self.try_replace_substs_in_root(a) {
a
} else {
Expand Down Expand Up @@ -723,23 +714,23 @@ impl<'tcx> ConstUnifyCtxt<'tcx> {
}
}
(Node::Binop(a_op, al, ar), Node::Binop(b_op, bl, br)) if a_op == b_op => {
self.try_unify_inner(a.subtree(al), b.subtree(bl))
&& self.try_unify_inner(a.subtree(ar), b.subtree(br))
self.try_unify(a.subtree(al), b.subtree(bl))
&& self.try_unify(a.subtree(ar), b.subtree(br))
}
(Node::UnaryOp(a_op, av), Node::UnaryOp(b_op, bv)) if a_op == b_op => {
self.try_unify_inner(a.subtree(av), b.subtree(bv))
self.try_unify(a.subtree(av), b.subtree(bv))
}
(Node::FunctionCall(a_f, a_args), Node::FunctionCall(b_f, b_args))
if a_args.len() == b_args.len() =>
{
self.try_unify_inner(a.subtree(a_f), b.subtree(b_f))
self.try_unify(a.subtree(a_f), b.subtree(b_f))
&& iter::zip(a_args, b_args)
.all(|(&an, &bn)| self.try_unify_inner(a.subtree(an), b.subtree(bn)))
.all(|(&an, &bn)| self.try_unify(a.subtree(an), b.subtree(bn)))
}
(Node::Cast(a_kind, a_operand, a_ty), Node::Cast(b_kind, b_operand, b_ty))
if (a_ty == b_ty) && (a_kind == b_kind) =>
{
self.try_unify_inner(a.subtree(a_operand), b.subtree(b_operand))
self.try_unify(a.subtree(a_operand), b.subtree(b_operand))
}
// use this over `_ => false` to make adding variants to `Node` less error prone
(Node::Cast(..), _)
Expand Down