From c3f9c4f4d4bbc83c7de79a09c7ec0e7fda8efc5e Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Tue, 20 Aug 2024 11:35:05 -0400 Subject: [PATCH 1/2] Use equality when relating formal and expected type in arg checking --- .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 9 ++++----- .../coercion/constrain-expectation-in-arg.rs | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) create mode 100644 tests/ui/coercion/constrain-expectation-in-arg.rs diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index eebb0217990df..16d65726128c3 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -292,21 +292,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let coerce_error = self.coerce(provided_arg, checked_ty, coerced_ty, AllowTwoPhase::Yes, None).err(); - if coerce_error.is_some() { return Compatibility::Incompatible(coerce_error); } - // 3. Check if the formal type is a supertype of the checked one - // and register any such obligations for future type checks - let supertype_error = self.at(&self.misc(provided_arg.span), self.param_env).sup( + // 3. Check if the formal type is actually equal to the checked one + // and register any such obligations for future type checks. + let formal_ty_error = self.at(&self.misc(provided_arg.span), self.param_env).eq( DefineOpaqueTypes::Yes, formal_input_ty, coerced_ty, ); // If neither check failed, the types are compatible - match supertype_error { + match formal_ty_error { Ok(InferOk { obligations, value: () }) => { self.register_predicates(obligations); Compatibility::Compatible diff --git a/tests/ui/coercion/constrain-expectation-in-arg.rs b/tests/ui/coercion/constrain-expectation-in-arg.rs new file mode 100644 index 0000000000000..858c3a0bdb572 --- /dev/null +++ b/tests/ui/coercion/constrain-expectation-in-arg.rs @@ -0,0 +1,19 @@ +//@ check-pass + +trait Trait { + type Item; +} + +struct Struct, B> { + pub field: A, +} + +fn identity(x: T) -> T { + x +} + +fn test, B>(x: &Struct) { + let x: &Struct<_, _> = identity(x); +} + +fn main() {} From 95b9ecd6d671637e9e3db55ed31d06882d3cad4d Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sun, 25 Aug 2024 12:45:58 -0400 Subject: [PATCH 2/2] Inline expected_inputs_for_expected_output into check_argument_types/check_expr_struct_fields --- compiler/rustc_hir_typeck/src/callee.rs | 21 ++----- compiler/rustc_hir_typeck/src/expr.rs | 25 +++++--- .../rustc_hir_typeck/src/fn_ctxt/_impl.rs | 39 +------------ .../rustc_hir_typeck/src/fn_ctxt/checks.rs | 58 ++++++++++++++----- .../coercion/constrain-expectation-in-arg.rs | 5 ++ 5 files changed, 71 insertions(+), 77 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/callee.rs b/compiler/rustc_hir_typeck/src/callee.rs index a4eec5f05a8ff..9863d0364498e 100644 --- a/compiler/rustc_hir_typeck/src/callee.rs +++ b/compiler/rustc_hir_typeck/src/callee.rs @@ -503,18 +503,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { let fn_sig = self.instantiate_binder_with_fresh_vars(call_expr.span, infer::FnCall, fn_sig); let fn_sig = self.normalize(call_expr.span, fn_sig); - // Call the generic checker. - let expected_arg_tys = self.expected_inputs_for_expected_output( - call_expr.span, - expected, - fn_sig.output(), - fn_sig.inputs(), - ); self.check_argument_types( call_expr.span, call_expr, fn_sig.inputs(), - expected_arg_tys, + fn_sig.output(), + expected, arg_exprs, fn_sig.c_variadic, TupleArgumentsFlag::DontTupleArguments, @@ -866,19 +860,12 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { // don't know the full details yet (`Fn` vs `FnMut` etc), but we // do know the types expected for each argument and the return // type. - - let expected_arg_tys = self.expected_inputs_for_expected_output( - call_expr.span, - expected, - fn_sig.output(), - fn_sig.inputs(), - ); - self.check_argument_types( call_expr.span, call_expr, fn_sig.inputs(), - expected_arg_tys, + fn_sig.output(), + expected, arg_exprs, fn_sig.c_variadic, TupleArgumentsFlag::TupleArguments, diff --git a/compiler/rustc_hir_typeck/src/expr.rs b/compiler/rustc_hir_typeck/src/expr.rs index 1362d3626efd4..f0d47e584ac28 100644 --- a/compiler/rustc_hir_typeck/src/expr.rs +++ b/compiler/rustc_hir_typeck/src/expr.rs @@ -1673,15 +1673,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ) { let tcx = self.tcx; - let expected_inputs = - self.expected_inputs_for_expected_output(span, expected, adt_ty, &[adt_ty]); - let adt_ty_hint = if let Some(expected_inputs) = expected_inputs { - expected_inputs.get(0).cloned().unwrap_or(adt_ty) - } else { - adt_ty - }; - // re-link the regions that EIfEO can erase. - self.demand_eqtype(span, adt_ty_hint, adt_ty); + let adt_ty = self.resolve_vars_with_obligations(adt_ty); + let adt_ty_hint = expected.only_has_type(self).and_then(|expected| { + self.fudge_inference_if_ok(|| { + let ocx = ObligationCtxt::new(self); + ocx.sup(&self.misc(span), self.param_env, expected, adt_ty)?; + if !ocx.select_where_possible().is_empty() { + return Err(TypeError::Mismatch); + } + Ok(self.resolve_vars_if_possible(adt_ty)) + }) + .ok() + }); + if let Some(adt_ty_hint) = adt_ty_hint { + // re-link the variables that the fudging above can create. + self.demand_eqtype(span, adt_ty_hint, adt_ty); + } let ty::Adt(adt, args) = adt_ty.kind() else { span_bug!(span, "non-ADT passed to check_expr_struct_fields"); diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs index 97c27680959f0..19f7950287f93 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs @@ -20,7 +20,6 @@ use rustc_infer::infer::canonical::{Canonical, OriginalQueryValues, QueryRespons use rustc_infer::infer::{DefineOpaqueTypes, InferResult}; use rustc_lint::builtin::SELF_CONSTRUCTOR_FROM_OUTER_ITEM; use rustc_middle::ty::adjustment::{Adjust, Adjustment, AutoBorrow, AutoBorrowMutability}; -use rustc_middle::ty::error::TypeError; use rustc_middle::ty::fold::TypeFoldable; use rustc_middle::ty::visit::{TypeVisitable, TypeVisitableExt}; use rustc_middle::ty::{ @@ -36,7 +35,7 @@ use rustc_span::Span; use rustc_target::abi::FieldIdx; use rustc_trait_selection::error_reporting::infer::need_type_info::TypeAnnotationNeeded; use rustc_trait_selection::traits::{ - self, NormalizeExt, ObligationCauseCode, ObligationCtxt, StructurallyNormalizeExt, + self, NormalizeExt, ObligationCauseCode, StructurallyNormalizeExt, }; use tracing::{debug, instrument}; @@ -689,42 +688,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { vec![ty_error; len] } - /// Unifies the output type with the expected type early, for more coercions - /// and forward type information on the input expressions. - #[instrument(skip(self, call_span), level = "debug")] - pub(crate) fn expected_inputs_for_expected_output( - &self, - call_span: Span, - expected_ret: Expectation<'tcx>, - formal_ret: Ty<'tcx>, - formal_args: &[Ty<'tcx>], - ) -> Option>> { - let formal_ret = self.resolve_vars_with_obligations(formal_ret); - let ret_ty = expected_ret.only_has_type(self)?; - - let expect_args = self - .fudge_inference_if_ok(|| { - let ocx = ObligationCtxt::new(self); - - // Attempt to apply a subtyping relationship between the formal - // return type (likely containing type variables if the function - // is polymorphic) and the expected return type. - // No argument expectations are produced if unification fails. - let origin = self.misc(call_span); - ocx.sup(&origin, self.param_env, ret_ty, formal_ret)?; - if !ocx.select_where_possible().is_empty() { - return Err(TypeError::Mismatch); - } - - // Record all the argument types, with the args - // produced from the above subtyping unification. - Ok(Some(formal_args.iter().map(|&ty| self.resolve_vars_if_possible(ty)).collect())) - }) - .unwrap_or_default(); - debug!(?formal_args, ?formal_ret, ?expect_args, ?expected_ret); - expect_args - } - pub(crate) fn resolve_lang_item_path( &self, lang_item: hir::LangItem, diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs index 16d65726128c3..bdf84f332166d 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs @@ -17,6 +17,7 @@ use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer; use rustc_index::IndexVec; use rustc_infer::infer::{DefineOpaqueTypes, InferOk, TypeTrace}; use rustc_middle::ty::adjustment::AllowTwoPhase; +use rustc_middle::ty::error::TypeError; use rustc_middle::ty::visit::TypeVisitableExt; use rustc_middle::ty::{self, IsSuggestable, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; @@ -25,7 +26,7 @@ use rustc_span::symbol::{kw, Ident}; use rustc_span::{sym, Span, DUMMY_SP}; use rustc_trait_selection::error_reporting::infer::{FailureCode, ObligationCauseExt}; use rustc_trait_selection::infer::InferCtxtExt; -use rustc_trait_selection::traits::{self, ObligationCauseCode, SelectionContext}; +use rustc_trait_selection::traits::{self, ObligationCauseCode, ObligationCtxt, SelectionContext}; use tracing::debug; use {rustc_ast as ast, rustc_hir as hir}; @@ -124,6 +125,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { }; if let Err(guar) = has_error { let err_inputs = self.err_args(args_no_rcvr.len(), guar); + let err_output = Ty::new_error(self.tcx, guar); let err_inputs = match tuple_arguments { DontTupleArguments => err_inputs, @@ -134,28 +136,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { sp, expr, &err_inputs, - None, + err_output, + NoExpectation, args_no_rcvr, false, tuple_arguments, method.ok().map(|method| method.def_id), ); - return Ty::new_error(self.tcx, guar); + return err_output; } let method = method.unwrap(); - // HACK(eddyb) ignore self in the definition (see above). - let expected_input_tys = self.expected_inputs_for_expected_output( - sp, - expected, - method.sig.output(), - &method.sig.inputs()[1..], - ); self.check_argument_types( sp, expr, &method.sig.inputs()[1..], - expected_input_tys, + method.sig.output(), + expected, args_no_rcvr, method.sig.c_variadic, tuple_arguments, @@ -175,8 +172,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { call_expr: &'tcx hir::Expr<'tcx>, // Types (as defined in the *signature* of the target function) formal_input_tys: &[Ty<'tcx>], - // More specific expected types, after unifying with caller output types - expected_input_tys: Option>>, + formal_output: Ty<'tcx>, + // Expected output from the parent expression or statement + expectation: Expectation<'tcx>, // The expressions for each provided argument provided_args: &'tcx [hir::Expr<'tcx>], // Whether the function is variadic, for example when imported from C @@ -210,6 +208,40 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { ); } + // First, let's unify the formal method signature with the expectation eagerly. + // We use this to guide coercion inference; it's output is "fudged" which means + // any remaining type variables are assigned to new, unrelated variables. This + // is because the inference guidance here is only speculative. + let formal_output = self.resolve_vars_with_obligations(formal_output); + let expected_input_tys: Option> = expectation + .only_has_type(self) + .and_then(|expected_output| { + self.fudge_inference_if_ok(|| { + let ocx = ObligationCtxt::new(self); + + // Attempt to apply a subtyping relationship between the formal + // return type (likely containing type variables if the function + // is polymorphic) and the expected return type. + // No argument expectations are produced if unification fails. + let origin = self.misc(call_span); + ocx.sup(&origin, self.param_env, expected_output, formal_output)?; + if !ocx.select_where_possible().is_empty() { + return Err(TypeError::Mismatch); + } + + // Record all the argument types, with the args + // produced from the above subtyping unification. + Ok(Some( + formal_input_tys + .iter() + .map(|&ty| self.resolve_vars_if_possible(ty)) + .collect(), + )) + }) + .ok() + }) + .unwrap_or_default(); + let mut err_code = E0061; // If the arguments should be wrapped in a tuple (ex: closures), unwrap them here diff --git a/tests/ui/coercion/constrain-expectation-in-arg.rs b/tests/ui/coercion/constrain-expectation-in-arg.rs index 858c3a0bdb572..c515dedc4bb4d 100644 --- a/tests/ui/coercion/constrain-expectation-in-arg.rs +++ b/tests/ui/coercion/constrain-expectation-in-arg.rs @@ -1,5 +1,10 @@ //@ check-pass +// Regression test for for #129286. +// Makes sure that we don't have unconstrained type variables that come from +// bivariant type parameters due to the way that we construct expectation types +// when checking call expressions in HIR typeck. + trait Trait { type Item; }