Skip to content

Commit

Permalink
Widen diagonal var during Type unwrapping in instanceof_tfunc (Ju…
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 authored and mkitti committed Dec 9, 2023
1 parent 8bdeceb commit b0a1fdb
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 4 deletions.
14 changes: 10 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,31 @@ add_tfunc(throw, 1, 1, @nospecs((𝕃::AbstractLattice, x)->Bottom), 0)
# if isexact is false, the actual runtime type may (will) be a subtype of t
# if isconcrete is true, the actual runtime type is definitely concrete (unreachable if not valid as a typeof)
# if istype is true, the actual runtime value will definitely be a type (e.g. this is false for Union{Type{Int}, Int})
function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
function instanceof_tfunc(@nospecialize(t), astag::Bool=false, @nospecialize(troot) = t)
if isa(t, Const)
if isa(t.val, Type) && valid_as_lattice(t.val, astag)
return t.val, true, isconcretetype(t.val), true
end
return Bottom, true, false, false # runtime throws on non-Type
end
t = widenconst(t)
troot = widenconst(troot)
if t === Bottom
return Bottom, true, true, false # runtime unreachable
elseif t === typeof(Bottom) || !hasintersect(t, Type)
return Bottom, true, false, false # literal Bottom or non-Type
elseif isType(t)
tp = t.parameters[1]
valid_as_lattice(tp, astag) || return Bottom, true, false, false # runtime unreachable / throws on non-Type
if troot isa UnionAll
# Free `TypeVar`s inside `Type` has violated the "diagonal" rule.
# Widen them before `UnionAll` rewraping to relax concrete constraint.
tp = widen_diagonal(tp, troot)
end
return tp, !has_free_typevars(tp), isconcretetype(tp), true
elseif isa(t, UnionAll)
t′ = unwrap_unionall(t)
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag)
t′′, isexact, isconcrete, istype = instanceof_tfunc(t′, astag, rewrap_unionall(t, troot))
tr = rewrap_unionall(t′′, t)
if t′′ isa DataType && t′′.name !== Tuple.name && !has_free_typevars(tr)
# a real instance must be within the declared bounds of the type,
Expand All @@ -128,8 +134,8 @@ function instanceof_tfunc(@nospecialize(t), astag::Bool=false)
end
return tr, isexact, isconcrete, istype
elseif isa(t, Union)
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag)
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag)
ta, isexact_a, isconcrete_a, istype_a = instanceof_tfunc(t.a, astag, troot)
tb, isexact_b, isconcrete_b, istype_b = instanceof_tfunc(t.b, astag, troot)
isconcrete = isconcrete_a && isconcrete_b
istype = istype_a && istype_b
# most users already handle the Union case, so here we assume that
Expand Down
5 changes: 5 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,11 @@ function rename_unionall(@nospecialize(u))
return UnionAll(nv, body{nv})
end

# remove concrete constraint on diagonal TypeVar if it comes from troot
function widen_diagonal(@nospecialize(t), troot::UnionAll)
body = ccall(:jl_widen_diagonal, Any, (Any, Any), t, troot)
end

function isvarargtype(@nospecialize(t))
return isa(t, Core.TypeofVararg)
end
Expand Down
206 changes: 206 additions & 0 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -4304,6 +4304,212 @@ int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv)
return sub;
}

// type utils
static void check_diagonal(jl_value_t *t, jl_varbinding_t *troot, int param)
{
if (jl_is_uniontype(t)) {
int i, len = 0;
jl_varbinding_t *v;
for (v = troot; v != NULL; v = v->prev)
len++;
int8_t *occurs = (int8_t *)alloca(len);
for (v = troot, i = 0; v != NULL; v = v->prev, i++)
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
check_diagonal(((jl_uniontype_t *)t)->a, troot, param);
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
int8_t occurs_inv = occurs[i] & 3;
int8_t occurs_cov = occurs[i] >> 2;
occurs[i] = v->occurs_inv | (v->occurs_cov << 2);
v->occurs_inv = occurs_inv;
v->occurs_cov = occurs_cov;
}
check_diagonal(((jl_uniontype_t *)t)->b, troot, param);
for (v = troot, i = 0; v != NULL; v = v->prev, i++) {
if (v->occurs_inv < (occurs[i] & 3))
v->occurs_inv = occurs[i] & 3;
if (v->occurs_cov < (occurs[i] >> 2))
v->occurs_cov = occurs[i] >> 2;
}
}
else if (jl_is_unionall(t)) {
assert(troot != NULL);
jl_varbinding_t *v1 = troot, *v2 = troot->prev;
while (v2 != NULL) {
if (v2->var == ((jl_unionall_t *)t)->var) {
v1->prev = v2->prev;
break;
}
v1 = v2;
v2 = v2->prev;
}
check_diagonal(((jl_unionall_t *)t)->body, troot, param);
v1->prev = v2;
}
else if (jl_is_datatype(t)) {
int nparam = jl_is_tuple_type(t) ? 1 : 2;
if (nparam < param) nparam = param;
for (size_t i = 0; i < jl_nparams(t); i++) {
check_diagonal(jl_tparam(t, i), troot, nparam);
}
}
else if (jl_is_vararg(t)) {
jl_value_t *T = jl_unwrap_vararg(t);
jl_value_t *N = jl_unwrap_vararg_num(t);
int n = (N && jl_is_long(N)) ? jl_unbox_long(N) : 2;
if (T && n > 0) check_diagonal(T, troot, param);
if (T && n > 1) check_diagonal(T, troot, param);
if (N) check_diagonal(N, troot, 2);
}
else if (jl_is_typevar(t)) {
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->var == (jl_tvar_t *)t) {
if (param == 1 && v->occurs_cov < 2) v->occurs_cov++;
if (param == 2 && v->occurs_inv < 2) v->occurs_inv++;
break;
}
}
if (v == NULL)
check_diagonal(((jl_tvar_t *)t)->ub, troot, 0);
}
}

static jl_value_t *insert_nondiagonal(jl_value_t *type, jl_varbinding_t *troot, int widen2ub)
{
if (jl_is_typevar(type)) {
int concretekind = widen2ub > 1 ? 0 : 1;
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->occurs_inv == 0 &&
v->occurs_cov > concretekind &&
v->var == (jl_tvar_t *)type)
break;
}
if (v != NULL) {
if (widen2ub) {
type = insert_nondiagonal(((jl_tvar_t *)type)->ub, troot, 2);
}
else {
// we must replace each covariant occurrence of newvar with a different newvar2<:newvar (diagonal rule)
if (v->innervars == NULL)
v->innervars = jl_alloc_array_1d(jl_array_any_type, 0);
jl_value_t *newvar = NULL, *lb = v->var->lb, *ub = (jl_value_t *)v->var;
jl_array_t *innervars = v->innervars;
JL_GC_PUSH4(&newvar, &lb, &ub, &innervars);
newvar = (jl_value_t *)jl_new_typevar(v->var->name, lb, ub);
jl_array_ptr_1d_push(innervars, newvar);
JL_GC_POP();
type = newvar;
}
}
}
else if (jl_is_unionall(type)) {
jl_value_t *body = ((jl_unionall_t*)type)->body;
jl_tvar_t *var = ((jl_unionall_t*)type)->var;
jl_varbinding_t *v = troot;
for (; v != NULL; v = v->prev) {
if (v->var == var)
break;
}
if (v) v->var = NULL; // Temporarily remove `type->var` from binding list.
jl_value_t *newbody = insert_nondiagonal(body, troot, widen2ub);
if (v) v->var = var; // And restore it after inner insertation.
jl_value_t *newvar = NULL;
JL_GC_PUSH2(&newbody, &newvar);
if (body == newbody || jl_has_typevar(newbody, var)) {
if (body != newbody)
newbody = jl_new_struct(jl_unionall_type, var, newbody);
// n.b. we do not widen lb, since that would be the wrong direction
newvar = insert_nondiagonal(var->ub, troot, widen2ub);
if (newvar != var->ub) {
newvar = (jl_value_t*)jl_new_typevar(var->name, var->lb, newvar);
newbody = jl_apply_type1(newbody, newvar);
newbody = jl_type_unionall((jl_tvar_t*)newvar, newbody);
}
}
type = newbody;
JL_GC_POP();
}
else if (jl_is_uniontype(type)) {
jl_value_t *a = ((jl_uniontype_t*)type)->a;
jl_value_t *b = ((jl_uniontype_t*)type)->b;
jl_value_t *newa = NULL;
jl_value_t *newb = NULL;
JL_GC_PUSH2(&newa, &newb);
newa = insert_nondiagonal(a, troot, widen2ub);
newb = insert_nondiagonal(b, troot, widen2ub);
if (newa != a || newb != b)
type = simple_union(newa, newb);
JL_GC_POP();
}
else if (jl_is_vararg(type)) {
// As for Vararg we'd better widen it's var to ub as otherwise they are still diagonal
jl_value_t *t = jl_unwrap_vararg(type);
jl_value_t *n = jl_unwrap_vararg_num(type);
if (widen2ub == 0)
widen2ub = !(n && jl_is_long(n)) || jl_unbox_long(n) > 1;
jl_value_t *newt;
JL_GC_PUSH2(&newt, &n);
newt = insert_nondiagonal(t, troot, widen2ub);
if (t != newt)
type = (jl_value_t *)jl_wrap_vararg(newt, n, 0);
JL_GC_POP();
}
else if (jl_is_datatype(type)) {
if (jl_is_tuple_type(type)) {
jl_svec_t *newparams = NULL;
jl_value_t *newelt = NULL;
JL_GC_PUSH2(&newparams, &newelt);
for (size_t i = 0; i < jl_nparams(type); i++) {
jl_value_t *elt = jl_tparam(type, i);
newelt = insert_nondiagonal(elt, troot, widen2ub);
if (elt != newelt) {
if (!newparams)
newparams = jl_svec_copy(((jl_datatype_t*)type)->parameters);
jl_svecset(newparams, i, newelt);
}
}
if (newparams)
type = (jl_value_t*)jl_apply_tuple_type(newparams, 1);
JL_GC_POP();
}
}
return type;
}

static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {
check_diagonal(t, troot, 0);
int any_concrete = 0;
for (jl_varbinding_t *v = troot; v != NULL; v = v->prev)
any_concrete |= v->occurs_cov > 1 && v->occurs_inv == 0;
if (!any_concrete)
return t; // no diagonal
return insert_nondiagonal(t, troot, 0);
}

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
nt = widen_diagonal(t, (jl_unionall_t *)u->body, &vb);
else
nt = _widen_diagonal(t, &vb);
if (vb.innervars != NULL) {
for (size_t i = 0; i < jl_array_nrows(vb.innervars); i++) {
jl_tvar_t *var = (jl_tvar_t*)jl_array_ptr_ref(vb.innervars, i);
nt = jl_type_unionall(var, nt);
}
}
JL_GC_POP();
return nt;
}

JL_DLLEXPORT jl_value_t *jl_widen_diagonal(jl_value_t *t, jl_unionall_t *ua)
{
return widen_diagonal(t, ua, NULL);
}

// specificity comparison

Expand Down
13 changes: 13 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5593,3 +5593,16 @@ end |> only === Float64
@test Base.infer_exception_type(c::Bool -> c ? 1 : 2) == Union{}
@test Base.infer_exception_type(c::Missing -> c ? 1 : 2) == TypeError
@test Base.infer_exception_type(c::Any -> c ? 1 : 2) == TypeError

# Issue #52168
f52168(x, t::Type) = x::NTuple{2, Base.inferencebarrier(t)::Type}
@test f52168((1, 2.), Any) === (1, 2.)

# Issue #27031
let x = 1, _Any = Any
@noinline bar27031(tt::Tuple{T,T}, ::Type{Val{T}}) where {T} = notsame27031(tt)
@noinline notsame27031(tt::Tuple{T, T}) where {T} = error()
@noinline notsame27031(tt::Tuple{T, S}) where {T, S} = "OK"
foo27031() = bar27031((x, 1.0), Val{_Any})
@test foo27031() == "OK"
end
11 changes: 11 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8059,3 +8059,14 @@ check_globalref_lowering() = @insert_global
let src = code_lowered(check_globalref_lowering)[1]
@test length(src.code) == 2
end

# Test correctness of widen_diagonal
let widen_diagonal(x::UnionAll) = Base.rewrap_unionall(Base.widen_diagonal(Base.unwrap_unionall(x), x), x),
check_widen_diagonal(x, y) = !<:(x, y) && x <: widen_diagonal(y)
@test Tuple{Int,Float64} <: widen_diagonal(NTuple)
@test Tuple{Int,Float64} <: widen_diagonal(Tuple{T,T} where {T})
@test Tuple{Real,Int,Float64} <: widen_diagonal(Tuple{S,Vararg{T}} where {S, T<:S})
@test Tuple{Int,Int,Float64,Float64} <: widen_diagonal(Tuple{S,S,Vararg{T}} where {S, T<:S})
@test Union{Tuple{T}, Tuple{T,Int}} where {T} === widen_diagonal(Union{Tuple{T}, Tuple{T,Int}} where {T})
@test Tuple === widen_diagonal(Union{Tuple{Vararg{S}}, Tuple{Vararg{T}}} where {S, T})
end

0 comments on commit b0a1fdb

Please sign in to comment.