From 7ebe20315f05041f68f7f99cd8491480ea5c0906 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Sun, 19 May 2024 13:10:05 +0800 Subject: [PATCH] typeintersect: followup cleanup for the nothrow path of type instantiation (#54514) (cherry picked from commit af545b90c5657e0e210aa638542052e16ca90fef) --- src/jltypes.c | 109 ++++++++++++++++++++++++++----------------- src/julia_internal.h | 2 +- src/subtype.c | 12 +++-- test/subtype.jl | 21 ++++++++- 4 files changed, 94 insertions(+), 50 deletions(-) diff --git a/src/jltypes.c b/src/jltypes.c index b39997ada2dd8..dd26d1df06625 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -1468,11 +1468,11 @@ jl_unionall_t *jl_rename_unionall(jl_unionall_t *u) return (jl_unionall_t*)t; } -jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val) +jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow) { if (val == (jl_value_t*)var) return t; - int nothrow = jl_is_typevar(val) ? 0 : 1; + nothrow = jl_is_typevar(val) ? 0 : nothrow; jl_typeenv_t env = { var, val, NULL }; return inst_type_w_(t, &env, NULL, 1, nothrow); } @@ -1694,7 +1694,7 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt, int cacheable) dt->hash = typekey_hash(dt->name, jl_svec_data(dt->parameters), l, cacheable); } -static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np) +static int check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, size_t np, int nothrow) { jl_value_t *wrapper = tn->wrapper; jl_value_t **bounds; @@ -1712,6 +1712,10 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si assert(jl_is_unionall(wrapper)); jl_tvar_t *tv = ((jl_unionall_t*)wrapper)->var; if (!within_typevar(params[i], bounds[2*i], bounds[2*i+1])) { + if (nothrow) { + JL_GC_POP(); + return 1; + } if (tv->lb != bounds[2*i] || tv->ub != bounds[2*i+1]) // pass a new version of `tv` containing the instantiated bounds tv = jl_new_typevar(tv->name, bounds[2*i], bounds[2*i+1]); @@ -1721,12 +1725,26 @@ static void check_datatype_parameters(jl_typename_t *tn, jl_value_t **params, si int j; for (j = 2*i + 2; j < 2*np; j++) { jl_value_t *bj = bounds[j]; - if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type) - bounds[j] = jl_substitute_var(bj, tv, params[i]); + if (bj != (jl_value_t*)jl_any_type && bj != jl_bottom_type) { + int isub = j & 1; + // use different nothrow level for lb and ub substitution. + // TODO: This assuming the top instantiation could only start with + // `nothrow == 2` or `nothrow == 0`. If `nothrow` is initially set to 1 + // then we might miss some inner error, perhaps the normal path should + // also follow this rule? + jl_value_t *nb = jl_substitute_var_nothrow(bj, tv, params[i], nothrow ? (isub ? 2 : 1) : 0 ); + if (nb == NULL) { + assert(nothrow); + JL_GC_POP(); + return 1; + } + bounds[j] = nb; + } } wrapper = ((jl_unionall_t*)wrapper)->body; } JL_GC_POP(); + return 0; } jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_GLOBALLY_ROOTED @@ -1943,13 +1961,8 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value // for whether this is even valid if (check && !istuple) { assert(ntp > 0); - JL_TRY { - check_datatype_parameters(tn, iparams, ntp); - } - JL_CATCH { - if (!nothrow) jl_rethrow(); + if (check_datatype_parameters(tn, iparams, ntp, nothrow)) return NULL; - } } else if (ntp == 0 && jl_emptytuple_type != NULL) { // empty tuple type case @@ -2301,7 +2314,8 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_ jl_value_t *elt = jl_svecref(tp, i); jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow); if (pi == NULL) { - if (i == ntp-1 && jl_is_vararg(elt)) { + assert(nothrow); + if (nothrow == 1 || (i == ntp-1 && jl_is_vararg(elt))) { t = NULL; break; } @@ -2320,6 +2334,10 @@ static jl_value_t *inst_tuple_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_ return t; } +// `nothrow` means that when type checking fails, the type instantiation should +// return `NULL` instead of immediately throwing an error. If `nothrow` == 2 then +// we further assume that the imprecise instantiation for non invariant parameters +// is acceptable, and inner error (`NULL`) would be ignored. static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t *stack, int check, int nothrow) { size_t i; @@ -2340,11 +2358,10 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t jl_value_t *var = NULL; jl_value_t *newbody = NULL; JL_GC_PUSH3(&lb, &var, &newbody); - JL_TRY { - lb = inst_type_w_(ua->var->lb, env, stack, check, 0); - } - JL_CATCH { - if (!nothrow) jl_rethrow(); + // set nothrow <= 1 to ensure lb's accuracy. + lb = inst_type_w_(ua->var->lb, env, stack, check, nothrow ? 1 : 0); + if (lb == NULL) { + assert(nothrow); t = NULL; } if (t != NULL) { @@ -2368,11 +2385,9 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t if (newbody == NULL) { t = NULL; } - else if (newbody == (jl_value_t*)jl_emptytuple_type) { - // NTuple{0} => Tuple{} can make a typevar disappear - t = (jl_value_t*)jl_emptytuple_type; - } - else if (nothrow && !jl_has_typevar(newbody, (jl_tvar_t *)var)) { + else if (!jl_has_typevar(newbody, (jl_tvar_t *)var)) { + // inner instantiation might make a typevar disappear, e.g. + // NTuple{0,T} => Tuple{} t = newbody; } else if (newbody != ua->body || var != (jl_value_t*)ua->var) { @@ -2389,16 +2404,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t jl_value_t *b = NULL; JL_GC_PUSH2(&a, &b); b = inst_type_w_(u->b, env, stack, check, nothrow); + if (nothrow) { + // ensure jl_type_union nothrow. + if (a && !(jl_is_typevar(a) || jl_is_type(a))) + a = NULL; + if (b && !(jl_is_typevar(b) || jl_is_type(b))) + b = NULL; + } if (a != u->a || b != u->b) { if (!check) { // fast path for `jl_rename_unionall`. t = jl_new_struct(jl_uniontype_type, a, b); } - else if (nothrow && a == NULL) { - t = b; - } - else if (nothrow && b == NULL) { - t = a; + else if (a == NULL || b == NULL) { + assert(nothrow); + t = nothrow == 1 ? NULL : a == NULL ? b : a; } else { assert(a != NULL && b != NULL); @@ -2416,15 +2436,21 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t JL_GC_PUSH2(&T, &N); if (v->T) { T = inst_type_w_(v->T, env, stack, check, nothrow); - if (T == NULL) - T = jl_bottom_type; - if (v->N) // This branch should never throw. - N = inst_type_w_(v->N, env, stack, check, 0); + if (T == NULL) { + if (nothrow == 2) + T = jl_bottom_type; + else + t = NULL; + } + if (t && v->N) { + // set nothrow <= 1 to ensure invariant parameter's accuracy. + N = inst_type_w_(v->N, env, stack, check, nothrow ? 1 : 0); + if (N == NULL) + t = NULL; + } } - if (T != v->T || N != v->N) { - // `Vararg` is special, we'd better handle inner error at Tuple level. + if (t && (T != v->T || N != v->N)) t = (jl_value_t*)jl_wrap_vararg(T, N, check, nothrow); - } JL_GC_POP(); return t; } @@ -2443,16 +2469,15 @@ static jl_value_t *inst_type_w_(jl_value_t *t, jl_typeenv_t *env, jl_typestack_t int bound = 0; for (i = 0; i < ntp; i++) { jl_value_t *elt = jl_svecref(tp, i); - JL_TRY { - jl_value_t *pi = inst_type_w_(elt, env, stack, check, 0); - iparams[i] = pi; - bound |= (pi != elt); - } - JL_CATCH { - if (!nothrow) jl_rethrow(); + // set nothrow <= 1 to ensure invariant parameter's accuracy. + jl_value_t *pi = inst_type_w_(elt, env, stack, check, nothrow ? 1 : 0); + if (pi == NULL) { + assert(nothrow); t = NULL; + break; } - if (t == NULL) break; + iparams[i] = pi; + bound |= (pi != elt); } // if t's parameters are not bound in the environment, return it uncopied (#9378) if (t != NULL && bound) diff --git a/src/julia_internal.h b/src/julia_internal.h index 5bd101d35d20c..ca8038c3f3f20 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -711,7 +711,7 @@ JL_DLLEXPORT int jl_type_morespecific_no_subtype(jl_value_t *a, jl_value_t *b); jl_value_t *jl_instantiate_type_with(jl_value_t *t, jl_value_t **env, size_t n); JL_DLLEXPORT jl_value_t *jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals); jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val); -jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val); +jl_value_t *jl_substitute_var_nothrow(jl_value_t *t, jl_tvar_t *var, jl_value_t *val, int nothrow); jl_unionall_t *jl_rename_unionall(jl_unionall_t *u); JL_DLLEXPORT jl_value_t *jl_unwrap_unionall(jl_value_t *v JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_rewrap_unionall(jl_value_t *t, jl_value_t *u); diff --git a/src/subtype.c b/src/subtype.c index c9ad92fff94d7..8a1ea03fdd6fd 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -2770,7 +2770,7 @@ static jl_value_t *omit_bad_union(jl_value_t *u, jl_tvar_t *t) res = jl_bottom_type; } else if (obviously_egal(var->lb, ub)) { - res = jl_substitute_var_nothrow(body, var, ub); + res = jl_substitute_var_nothrow(body, var, ub, 2); if (res == NULL) res = jl_bottom_type; } @@ -2961,9 +2961,11 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind } if (varval) { if (ub_has_dep) { // inner substitution has been handled - btemp->ub = jl_substitute_var_nothrow(btemp->ub, vb->var, varval); - if (btemp->ub == NULL) + jl_value_t *bub = jl_substitute_var_nothrow(btemp->ub, vb->var, varval, 2); + if (bub == NULL) res = jl_bottom_type; + else + btemp->ub = bub; } } else if (btemp->ub == (jl_value_t*)vb->var) { @@ -2998,12 +3000,12 @@ static jl_value_t *finish_unionall(jl_value_t *res JL_MAYBE_UNROOTED, jl_varbind if (varval) { // you can construct `T{x} where x` even if T's parameter is actually // limited. in that case we might get an invalid instantiation here. - res = jl_substitute_var_nothrow(res, vb->var, varval); + res = jl_substitute_var_nothrow(res, vb->var, varval, 2); // simplify chains of UnionAlls where bounds become equal while (res != NULL && jl_is_unionall(res) && obviously_egal(((jl_unionall_t*)res)->var->lb, ((jl_unionall_t*)res)->var->ub)) { jl_unionall_t * ures = (jl_unionall_t *)res; - res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb); + res = jl_substitute_var_nothrow(ures->body, ures->var, ures->var->lb, 2); } if (res == NULL) res = jl_bottom_type; diff --git a/test/subtype.jl b/test/subtype.jl index d222a35d2e6c9..d3f1dc217f699 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -2608,13 +2608,30 @@ end #issue 54356 abstract type A54356{T<:Real} end struct B54356{T} <: A54356{T} end -let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T} - # general parameters check +struct C54356{S,T<:Union{S,Complex{S}}} end +struct D54356{S<:Real,T} end +let S = Tuple{Val, Val{T}} where {T}, R = Tuple{Val{Val{T}}, Val{T}} where {T}, + SS = Tuple{Val, Val{T}, Val{T}} where {T}, RR = Tuple{Val{Val{T}}, Val{T}, Val{T}} where {T} + # parameters check for self @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Complex{B}}}, S{1}, R{1}) + # parameters check for supertype (B54356 -> A54356) @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, B54356{B}}}, S{1}, R{1}) + # enure unused TypeVar skips the `UnionAll` wrapping + @testintersect(Tuple{Val{A}, A} where {B, A<:(Union{Val{B}, D54356{B,C}} where {C})}, S{1}, R{1}) + # invariant parameter should not get narrowed + @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Val{Union{Int,Complex{B}}}}}, S{1}, R{1}) + # bit value could not be `Union` element + @testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Val{B}}}, SS{1}, RR{1}) + @testintersect(Tuple{Val{A}, A, Val{B}} where {B, A<:Union{B, Complex{B}}}, SS{1}, Union{}) + # `check_datatype_parameters` should ignore bad `Union` elements in constraint's ub + T = Tuple{Val{Union{Val{Nothing}, Val{C54356{V,V}}}}, Val{Nothing}} where {Nothing<:V<:Nothing} + @test T <: S{Nothing} + @test T <: Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}} + @test T <: typeintersect(Tuple{Val{A}, A} where {B, C, A<:Union{Val{B}, Val{C54356{B,C}}}}, S{Nothing}) # extra check for Vararg @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NTuple{B,Any}}}, S{-1}, R{-1}) @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Any,Vararg{Any,B}}}}, S{-1}, R{-1}) + @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, Tuple{Vararg{Int,Union{Int,Complex{B}}}}}}, S{1}, R{1}) # extra check for NamedTuple @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{1}, R{1}) @testintersect(Tuple{Val{A}, A} where {B, A<:Union{Val{B}, NamedTuple{B,Tuple{Int}}}}, S{(1,)}, R{(1,)})