From 0c240733b9144ce19013b259c3c8ca9492b32682 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Fri, 7 Apr 2023 14:26:54 -0400 Subject: [PATCH] slightly optimize has_free_typevars (#49278) Manually convert these to tail-recursive form, so the stack can be unwound directly as soon as it finds an answer in many common cases (DataType with many simple UnionAll wrappers). --- src/jltypes.c | 451 +++++++++++++++++++++++++++++--------------------- 1 file changed, 260 insertions(+), 191 deletions(-) diff --git a/src/jltypes.c b/src/jltypes.c index 918ac3b8292dd..dc58cb57bfde6 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -37,85 +37,119 @@ static int typeenv_has(jl_typeenv_t *env, jl_tvar_t *v) JL_NOTSAFEPOINT return 0; } -static int layout_uses_free_typevars(jl_value_t *v, jl_typeenv_t *env) +static int typeenv_has_ne(jl_typeenv_t *env, jl_tvar_t *v) JL_NOTSAFEPOINT { - if (jl_typeis(v, jl_tvar_type)) - return !typeenv_has(env, (jl_tvar_t*)v); - if (jl_is_uniontype(v)) - return layout_uses_free_typevars(((jl_uniontype_t*)v)->a, env) || - layout_uses_free_typevars(((jl_uniontype_t*)v)->b, env); - if (jl_is_vararg(v)) { - jl_vararg_t *vm = (jl_vararg_t*)v; - if (vm->T && layout_uses_free_typevars(vm->T, env)) - return 1; - if (vm->N && layout_uses_free_typevars(vm->N, env)) - return 1; - return 0; - } - if (jl_is_unionall(v)) { - jl_unionall_t *ua = (jl_unionall_t*)v; - jl_typeenv_t newenv = { ua->var, NULL, env }; - return layout_uses_free_typevars(ua->body, &newenv); + while (env != NULL) { + if (env->var == v) + return env->val != (jl_value_t*)v; // consider it actually not present if it is bound to itself unchanging + env = env->prev; } - if (jl_is_datatype(v)) { - jl_datatype_t *dt = (jl_datatype_t*)v; - if (dt->layout || dt->isconcretetype || !dt->name->mayinlinealloc) - return 0; - if (dt->name == jl_namedtuple_typename) - return layout_uses_free_typevars(jl_tparam0(dt), env) || layout_uses_free_typevars(jl_tparam1(dt), env); - if (dt->name == jl_tuple_typename) 
- // conservative, since we don't want to inline an abstract tuple, - // and we currently declare !has_fixed_layout for these, but that - // means we also won't be able to inline a tuple which is concrete - // except for the use of free type-vars - return 1; - jl_svec_t *types = jl_get_fieldtypes(dt); - size_t i, l = jl_svec_len(types); - for (i = 0; i < l; i++) { - jl_value_t *ft = jl_svecref(types, i); - if (layout_uses_free_typevars(ft, env)) { - // This might be inline-alloc, but we don't know the layout + return 0; +} + + +static int layout_uses_free_typevars(jl_value_t *v, jl_typeenv_t *env) +{ + while (1) { + if (jl_typeis(v, jl_tvar_type)) + return !typeenv_has(env, (jl_tvar_t*)v); + while (jl_is_unionall(v)) { + jl_unionall_t *ua = (jl_unionall_t*)v; + jl_typeenv_t *newenv = (jl_typeenv_t*)alloca(sizeof(jl_typeenv_t)); + newenv->var = ua->var; + newenv->val = NULL; + newenv->prev = env; + env = newenv; + v = ua->body; + } + if (jl_is_datatype(v)) { + jl_datatype_t *dt = (jl_datatype_t*)v; + if (dt->layout || dt->isconcretetype || !dt->name->mayinlinealloc) + return 0; + if (dt->name == jl_namedtuple_typename) + return layout_uses_free_typevars(jl_tparam0(dt), env) || layout_uses_free_typevars(jl_tparam1(dt), env); + if (dt->name == jl_tuple_typename) + // conservative, since we don't want to inline an abstract tuple, + // and we currently declare !has_fixed_layout for these, but that + // means we also won't be able to inline a tuple which is concrete + // except for the use of free type-vars return 1; + jl_svec_t *types = jl_get_fieldtypes(dt); + size_t i, l = jl_svec_len(types); + for (i = 0; i < l; i++) { + jl_value_t *ft = jl_svecref(types, i); + if (layout_uses_free_typevars(ft, env)) + // This might be inline-alloc, but we don't know the layout + return 1; } + return 0; + } + else if (jl_is_uniontype(v)) { + if (layout_uses_free_typevars(((jl_uniontype_t*)v)->a, env)) + return 1; + v = ((jl_uniontype_t*)v)->b; + } + else if (jl_is_vararg(v)) { + 
jl_vararg_t *vm = (jl_vararg_t*)v; + if (!vm->T) + return 0; + if (vm->N && layout_uses_free_typevars(vm->N, env)) + return 1; + v = vm->T; + } + else { + return 0; } } - return 0; } static int has_free_typevars(jl_value_t *v, jl_typeenv_t *env) JL_NOTSAFEPOINT { - if (jl_typeis(v, jl_tvar_type)) { - return !typeenv_has(env, (jl_tvar_t*)v); - } - if (jl_is_uniontype(v)) - return has_free_typevars(((jl_uniontype_t*)v)->a, env) || - has_free_typevars(((jl_uniontype_t*)v)->b, env); - if (jl_is_vararg(v)) { - jl_vararg_t *vm = (jl_vararg_t*)v; - if (vm->T) { - if (has_free_typevars(vm->T, env)) - return 1; - return vm->N && has_free_typevars(vm->N, env); + while (1) { + if (jl_typeis(v, jl_tvar_type)) { + return !typeenv_has(env, (jl_tvar_t*)v); } - } - if (jl_is_unionall(v)) { - jl_unionall_t *ua = (jl_unionall_t*)v; - jl_typeenv_t newenv = { ua->var, NULL, env }; - return has_free_typevars(ua->var->lb, env) || has_free_typevars(ua->var->ub, env) || - has_free_typevars(ua->body, &newenv); - } - if (jl_is_datatype(v)) { - int expect = ((jl_datatype_t*)v)->hasfreetypevars; - if (expect == 0 || env == NULL) - return expect; - size_t i; - for (i = 0; i < jl_nparams(v); i++) { - if (has_free_typevars(jl_tparam(v, i), env)) { + while (jl_is_unionall(v)) { + jl_unionall_t *ua = (jl_unionall_t*)v; + if (ua->var->lb != jl_bottom_type && has_free_typevars(ua->var->lb, env)) + return 1; + if (ua->var->ub != (jl_value_t*)jl_any_type && has_free_typevars(ua->var->ub, env)) return 1; + jl_typeenv_t *newenv = (jl_typeenv_t*)alloca(sizeof(jl_typeenv_t)); + newenv->var = ua->var; + newenv->val = NULL; + newenv->prev = env; + env = newenv; + v = ua->body; + } + if (jl_is_datatype(v)) { + int expect = ((jl_datatype_t*)v)->hasfreetypevars; + if (expect == 0 || env == NULL) + return expect; + size_t i; + for (i = 0; i < jl_nparams(v); i++) { + if (has_free_typevars(jl_tparam(v, i), env)) + return 1; } + return 0; + } + else if (jl_is_uniontype(v)) { + if 
(has_free_typevars(((jl_uniontype_t*)v)->a, env)) + return 1; + v = ((jl_uniontype_t*)v)->b; + } + else if (jl_is_vararg(v)) { + jl_vararg_t *vm = (jl_vararg_t*)v; + if (!vm->T) + return 0; + if (vm->N && has_free_typevars(vm->N, env)) + return 1; + v = vm->T; + } + else { + return 0; } } - return 0; } JL_DLLEXPORT int jl_has_free_typevars(jl_value_t *v) JL_NOTSAFEPOINT @@ -125,36 +159,48 @@ JL_DLLEXPORT int jl_has_free_typevars(jl_value_t *v) JL_NOTSAFEPOINT static void find_free_typevars(jl_value_t *v, jl_typeenv_t *env, jl_array_t *out) { - if (jl_typeis(v, jl_tvar_type)) { - if (!typeenv_has(env, (jl_tvar_t*)v)) - jl_array_ptr_1d_push(out, v); - } - else if (jl_is_uniontype(v)) { - find_free_typevars(((jl_uniontype_t*)v)->a, env, out); - find_free_typevars(((jl_uniontype_t*)v)->b, env, out); - } - else if (jl_is_vararg(v)) { - jl_vararg_t *vm = (jl_vararg_t *)v; - if (vm->T) { - find_free_typevars(vm->T, env, out); - if (vm->N) { + while (1) { + if (jl_typeis(v, jl_tvar_type)) { + if (!typeenv_has(env, (jl_tvar_t*)v)) + jl_array_ptr_1d_push(out, v); + return; + } + while (jl_is_unionall(v)) { + jl_unionall_t *ua = (jl_unionall_t*)v; + if (ua->var->lb != jl_bottom_type) + find_free_typevars(ua->var->lb, env, out); + if (ua->var->ub != (jl_value_t*)jl_any_type) + find_free_typevars(ua->var->ub, env, out); + jl_typeenv_t *newenv = (jl_typeenv_t*)alloca(sizeof(jl_typeenv_t)); + newenv->var = ua->var; + newenv->val = NULL; + newenv->prev = env; + env = newenv; + v = ua->body; + } + if (jl_is_datatype(v)) { + if (!((jl_datatype_t*)v)->hasfreetypevars) + return; + size_t i; + for (i = 0; i < jl_nparams(v); i++) + find_free_typevars(jl_tparam(v, i), env, out); + return; + } + else if (jl_is_uniontype(v)) { + find_free_typevars(((jl_uniontype_t*)v)->a, env, out); + v = ((jl_uniontype_t*)v)->b; + } + else if (jl_is_vararg(v)) { + jl_vararg_t *vm = (jl_vararg_t *)v; + if (!vm->T) + return; + if (vm->N) // this swaps the visited order, but we don't mind it 
find_free_typevars(vm->N, env, out); - } + v = vm->T; } - } - else if (jl_is_unionall(v)) { - jl_unionall_t *ua = (jl_unionall_t*)v; - jl_typeenv_t newenv = { ua->var, NULL, env }; - find_free_typevars(ua->var->lb, env, out); - find_free_typevars(ua->var->ub, env, out); - find_free_typevars(ua->body, &newenv, out); - } - else if (jl_is_datatype(v)) { - if (!((jl_datatype_t*)v)->hasfreetypevars) + else { return; - size_t i; - for (i=0; i < jl_nparams(v); i++) - find_free_typevars(jl_tparam(v,i), env, out); + } } } @@ -170,41 +216,55 @@ JL_DLLEXPORT jl_array_t *jl_find_free_typevars(jl_value_t *v) // test whether a type has vars bound by the given environment static int jl_has_bound_typevars(jl_value_t *v, jl_typeenv_t *env) JL_NOTSAFEPOINT { - if (jl_typeis(v, jl_tvar_type)) - return typeenv_has(env, (jl_tvar_t*)v); - if (jl_is_uniontype(v)) - return jl_has_bound_typevars(((jl_uniontype_t*)v)->a, env) || - jl_has_bound_typevars(((jl_uniontype_t*)v)->b, env); - if (jl_is_vararg(v)) { - jl_vararg_t *vm = (jl_vararg_t *)v; - return vm->T && (jl_has_bound_typevars(vm->T, env) || - (vm->N && jl_has_bound_typevars(vm->N, env))); - } - if (jl_is_unionall(v)) { - jl_unionall_t *ua = (jl_unionall_t*)v; - if (jl_has_bound_typevars(ua->var->lb, env) || jl_has_bound_typevars(ua->var->ub, env)) - return 1; - jl_typeenv_t *te = env; - while (te != NULL) { - if (te->var == ua->var) - break; - te = te->prev; + while (1) { + if (jl_typeis(v, jl_tvar_type)) { + return typeenv_has_ne(env, (jl_tvar_t*)v); } - if (te) te->var = NULL; // temporarily remove this var from env - int ans = jl_has_bound_typevars(ua->body, env); - if (te) te->var = ua->var; - return ans; - } - if (jl_is_datatype(v)) { - if (!((jl_datatype_t*)v)->hasfreetypevars) + while (jl_is_unionall(v)) { + jl_unionall_t *ua = (jl_unionall_t*)v; + if (ua->var->lb != jl_bottom_type && jl_has_bound_typevars(ua->var->lb, env)) + return 1; + if (ua->var->ub != (jl_value_t*)jl_any_type && jl_has_bound_typevars(ua->var->ub, env)) 
+ return 1; + // Temporarily remove this var from env if necessary + // Note that ua->var might be bound more than once in the env, so + // we remove it by setting it to itself in a new env. + if (typeenv_has_ne(env, ua->var)) { + jl_typeenv_t *newenv = (jl_typeenv_t*)alloca(sizeof(jl_typeenv_t)); + newenv->var = ua->var; + newenv->val = (jl_value_t*)ua->var; + newenv->prev = env; + env = newenv; + } + v = ua->body; + } + if (jl_is_datatype(v)) { + if (!((jl_datatype_t*)v)->hasfreetypevars) + return 0; + size_t i; + for (i = 0; i < jl_nparams(v); i++) { + if (jl_has_bound_typevars(jl_tparam(v, i), env)) + return 1; + } return 0; - size_t i; - for (i=0; i < jl_nparams(v); i++) { - if (jl_has_bound_typevars(jl_tparam(v,i), env)) + } + else if (jl_is_uniontype(v)) { + if (jl_has_bound_typevars(((jl_uniontype_t*)v)->a, env)) + return 1; + v = ((jl_uniontype_t*)v)->b; + } + else if (jl_is_vararg(v)) { + jl_vararg_t *vm = (jl_vararg_t *)v; + if (!vm->T) + return 0; + if (vm->N && jl_has_bound_typevars(vm->N, env)) return 1; + v = vm->T; + } + else { + return 0; } } - return 0; } JL_DLLEXPORT int jl_has_typevar(jl_value_t *t, jl_tvar_t *v) JL_NOTSAFEPOINT @@ -283,26 +343,28 @@ JL_DLLEXPORT int jl_get_size(jl_value_t *val, size_t *pnt) static int count_union_components(jl_value_t **types, size_t n) { - size_t i, c=0; - for(i=0; i < n; i++) { + size_t i, c = 0; + for (i = 0; i < n; i++) { jl_value_t *e = types[i]; - if (jl_is_uniontype(e)) { + while (jl_is_uniontype(e)) { jl_uniontype_t *u = (jl_uniontype_t*)e; c += count_union_components(&u->a, 1); - c += count_union_components(&u->b, 1); - } - else { - c++; + e = u->b; } + c++; } return c; } int jl_count_union_components(jl_value_t *v) { - if (!jl_is_uniontype(v)) return 1; - jl_uniontype_t *u = (jl_uniontype_t*)v; - return jl_count_union_components(u->a) + jl_count_union_components(u->b); + size_t c = 0; + while (jl_is_uniontype(v)) { + jl_uniontype_t *u = (jl_uniontype_t*)v; + c += jl_count_union_components(u->a); + v = 
u->b; + } + return c + 1; } // Return the `*pi`th element of a nested type union, according to a @@ -310,16 +372,16 @@ int jl_count_union_components(jl_value_t *v) // considered an "element". `*pi` is destroyed in the process. static jl_value_t *nth_union_component(jl_value_t *v, int *pi) JL_NOTSAFEPOINT { - if (!jl_is_uniontype(v)) { - if (*pi == 0) - return v; - (*pi)--; - return NULL; + while (jl_is_uniontype(v)) { + jl_uniontype_t *u = (jl_uniontype_t*)v; + jl_value_t *a = nth_union_component(u->a, pi); + if (a) return a; + v = u->b; } - jl_uniontype_t *u = (jl_uniontype_t*)v; - jl_value_t *a = nth_union_component(u->a, pi); - if (a) return a; - return nth_union_component(u->b, pi); + if (*pi == 0) + return v; + (*pi)--; + return NULL; } jl_value_t *jl_nth_union_component(jl_value_t *v, int i) JL_NOTSAFEPOINT @@ -330,12 +392,11 @@ jl_value_t *jl_nth_union_component(jl_value_t *v, int i) JL_NOTSAFEPOINT // inverse of jl_nth_union_component int jl_find_union_component(jl_value_t *haystack, jl_value_t *needle, unsigned *nth) JL_NOTSAFEPOINT { - if (jl_is_uniontype(haystack)) { - if (jl_find_union_component(((jl_uniontype_t*)haystack)->a, needle, nth)) - return 1; - if (jl_find_union_component(((jl_uniontype_t*)haystack)->b, needle, nth)) + while (jl_is_uniontype(haystack)) { + jl_uniontype_t *u = (jl_uniontype_t*)haystack; + if (jl_find_union_component(u->a, needle, nth)) return 1; - return 0; + haystack = u->b; } if (needle == haystack) return 1; @@ -346,17 +407,15 @@ int jl_find_union_component(jl_value_t *haystack, jl_value_t *needle, unsigned * static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx) JL_NOTSAFEPOINT { size_t i; - for(i=0; i < n; i++) { + for (i = 0; i < n; i++) { jl_value_t *e = types[i]; - if (jl_is_uniontype(e)) { + while (jl_is_uniontype(e)) { jl_uniontype_t *u = (jl_uniontype_t*)e; flatten_type_union(&u->a, 1, out, idx); - flatten_type_union(&u->b, 1, out, idx); - } - else { - out[*idx] = e; - (*idx)++; + 
e = u->b; } + out[*idx] = e; + (*idx)++; } } @@ -1168,6 +1227,8 @@ jl_unionall_t *jl_rename_unionall(jl_unionall_t *u) jl_value_t *jl_substitute_var(jl_value_t *t, jl_tvar_t *var, jl_value_t *val) { + if (val == (jl_value_t*)var) + return t; jl_typeenv_t env = { var, val, NULL }; return inst_type_w_(t, &env, NULL, 1); } @@ -1421,45 +1482,54 @@ static jl_value_t *extract_wrapper(jl_value_t *t JL_PROPAGATES_ROOT) JL_GLOBALLY int _may_substitute_ub(jl_value_t *v, jl_tvar_t *var, int inside_inv, int *cov_count) JL_NOTSAFEPOINT { - if (v == (jl_value_t*)var) { - if (inside_inv) { - return 0; + while (1) { + if (v == (jl_value_t*)var) { + if (inside_inv) { + return 0; + } + else { + (*cov_count)++; + return *cov_count <= 1 || jl_is_concrete_type(var->ub); + } } - else { - (*cov_count)++; - return *cov_count <= 1 || jl_is_concrete_type(var->ub); + while (jl_is_unionall(v)) { + jl_unionall_t *ua = (jl_unionall_t*)v; + if (ua->var == var) + return 1; + if (ua->var->lb != jl_bottom_type && !_may_substitute_ub(ua->var->lb, var, inside_inv, cov_count)) + return 0; + if (ua->var->ub != (jl_value_t*)jl_any_type && !_may_substitute_ub(ua->var->ub, var, inside_inv, cov_count)) + return 0; + v = ua->body; } - } - else if (jl_is_uniontype(v)) { - return _may_substitute_ub(((jl_uniontype_t*)v)->a, var, inside_inv, cov_count) && - _may_substitute_ub(((jl_uniontype_t*)v)->b, var, inside_inv, cov_count); - } - else if (jl_is_unionall(v)) { - jl_unionall_t *ua = (jl_unionall_t*)v; - if (ua->var == var) + if (jl_is_datatype(v)) { + int invar = inside_inv || !jl_is_tuple_type(v); + for (size_t i = 0; i < jl_nparams(v); i++) { + if (!_may_substitute_ub(jl_tparam(v, i), var, invar, cov_count)) + return 0; + } return 1; - return _may_substitute_ub(ua->var->lb, var, inside_inv, cov_count) && - _may_substitute_ub(ua->var->ub, var, inside_inv, cov_count) && - _may_substitute_ub(ua->body, var, inside_inv, cov_count); - } - else if (jl_is_datatype(v)) { - int invar = inside_inv || 
!jl_is_tuple_type(v); - for (size_t i = 0; i < jl_nparams(v); i++) { - if (!_may_substitute_ub(jl_tparam(v,i), var, invar, cov_count)) + } + else if (jl_is_uniontype(v)) { + // TODO: if !inside_inv, these don't have to share the changes to cov_count + if (!_may_substitute_ub(((jl_uniontype_t*)v)->a, var, inside_inv, cov_count)) return 0; + v = ((jl_uniontype_t*)v)->b; + } + else if (jl_is_vararg(v)) { + jl_vararg_t *va = (jl_vararg_t*)v; + if (!va->T) + return 1; + if (va->N && !_may_substitute_ub(va->N, var, 1, cov_count)) + return 0; + if (!jl_is_concrete_type(var->ub)) + inside_inv = 1; // treat as invariant inside vararg, for the sake of this algorithm + v = va->T; + } + else { + return 1; } } - else if (jl_is_vararg(v)) { - jl_vararg_t *va = (jl_vararg_t*)v; - int old_count = *cov_count; - if (va->T && !_may_substitute_ub(va->T, var, inside_inv, cov_count)) - return 0; - if (*cov_count > old_count && !jl_is_concrete_type(var->ub)) - return 0; - if (va->N && !_may_substitute_ub(va->N, var, 1, cov_count)) - return 0; - } - return 1; } // Check whether `var` may be replaced with its upper bound `ub` in `v where var<:ub` @@ -1475,7 +1545,6 @@ int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT jl_value_t *normalize_unionalls(jl_value_t *t) { - JL_GC_PUSH1(&t); if (jl_is_uniontype(t)) { jl_uniontype_t *u = (jl_uniontype_t*)t; jl_value_t *a = NULL; @@ -1491,14 +1560,14 @@ jl_value_t *normalize_unionalls(jl_value_t *t) else if (jl_is_unionall(t)) { jl_unionall_t *u = (jl_unionall_t*)t; jl_value_t *body = normalize_unionalls(u->body); + JL_GC_PUSH1(&body); if (body != u->body) { - JL_GC_PUSH1(&body); t = jl_new_struct(jl_unionall_type, u->var, body); - JL_GC_POP(); u = (jl_unionall_t*)t; } if (u->var->lb == u->var->ub || may_substitute_ub(body, u->var)) { + body = (jl_value_t*)u; JL_TRY { t = jl_instantiate_unionall(u, u->var->ub); } @@ -1507,8 +1576,8 @@ jl_value_t *normalize_unionalls(jl_value_t *t) // (may happen for bounds inconsistent with the 
wrapper's bounds) } } + JL_GC_POP(); } - JL_GC_POP(); return t; } @@ -1588,9 +1657,9 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value } jl_datatype_t *ndt = NULL; - jl_value_t *last = iparams[ntp - 1]; - JL_GC_PUSH3(&p, &ndt, &last); + JL_GC_PUSH2(&p, &ndt); + jl_value_t *last = iparams[ntp - 1]; if (istuple && ntp > 0 && jl_is_vararg(last)) { // normalize Tuple{..., Vararg{Int, 3}} to Tuple{..., Int, Int, Int} jl_value_t *va = jl_unwrap_unionall(last);