Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand more Vararg elements during re-intersection if valid. #46604

Merged
merged 1 commit into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 96 additions & 51 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ typedef struct jl_varbinding_t {
int8_t occurs_inv; // occurs in invariant position
int8_t occurs_cov; // # of occurrences in covariant position
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
int8_t max_offset; // record the maximum positive offset of the variable (up to 32)
// max_offset < 0 if this variable occurs outside VarargNum.
// constraintkind: in covariant position, we try three different ways to compute var ∩ type:
// let ub = var.ub ∩ type
// 0 - var.ub <: type ? var : ub
Expand All @@ -77,6 +79,7 @@ typedef struct jl_varbinding_t {
int8_t constraintkind;
int8_t intvalued; // intvalued: must be integer-valued; i.e. occurs as N in Vararg{_,N}
int8_t limited;
int8_t intersected; // whether this variable has been intersected
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
Expand Down Expand Up @@ -168,9 +171,9 @@ static int current_env_length(jl_stenv_t *e)
typedef struct {
int8_t *buf;
int rdepth;
int8_t _space[24]; // == 8 * 3
int8_t _space[32]; // == 8 * 4
jl_gcframe_t gcframe;
jl_value_t *roots[24];
jl_value_t *roots[24]; // == 8 * 3
} jl_savedenv_t;

static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
Expand Down Expand Up @@ -200,6 +203,7 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
se->buf[j++] = v->occurs;
se->buf[j++] = v->occurs_inv;
se->buf[j++] = v->occurs_cov;
se->buf[j++] = v->max_offset;
v = v->prev;
}
assert(i == nroots); (void)nroots;
Expand Down Expand Up @@ -231,7 +235,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
ct->gcstack = &se->gcframe;
}
}
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space);
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space);
#ifdef __clang_gcanalyzer__
memset(se->buf, 0, len * 3);
#endif
Expand Down Expand Up @@ -281,6 +285,7 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
v->occurs = se->buf[j++];
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v->max_offset = se->buf[j++];
v = v->prev;
}
assert(i == nroots); (void)nroots;
Expand Down Expand Up @@ -677,6 +682,10 @@ static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param)
else if (vb->occurs_cov < 2) {
vb->occurs_cov++;
}
// Always set `max_offset` to `-1` during the 1st round intersection.
// Would be recovered in `intersect_varargs`/`subtype_tuple_varargs` if needed.
if (!vb->intersected)
vb->max_offset = -1;
}
}

Expand Down Expand Up @@ -888,7 +897,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
Expand Down Expand Up @@ -1008,39 +1017,30 @@ static int subtype_tuple_varargs(
jl_value_t *xp0 = jl_unwrap_vararg(vtx); jl_value_t *xp1 = jl_unwrap_vararg_num(vtx);
jl_value_t *yp0 = jl_unwrap_vararg(vty); jl_value_t *yp1 = jl_unwrap_vararg_num(vty);

jl_varbinding_t *xlv = NULL, *ylv = NULL;
if (xp1 && jl_is_typevar(xp1))
xlv = lookup(e, (jl_tvar_t*)xp1);
if (yp1 && jl_is_typevar(yp1))
ylv = lookup(e, (jl_tvar_t*)yp1);

int8_t max_offsetx = xlv ? xlv->max_offset : 0;
int8_t max_offsety = ylv ? ylv->max_offset : 0;

jl_value_t *xl = xlv ? xlv->lb : xp1;
jl_value_t *yl = ylv ? ylv->lb : yp1;

if (!xp1) {
jl_value_t *yl = yp1;
if (yl) {
// Unconstrained on the left, constrained on the right
if (jl_is_typevar(yl)) {
jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl);
if (ylv)
yl = ylv->lb;
}
if (jl_is_long(yl)) {
return 0;
}
}
// Unconstrained on the left, constrained on the right
if (yl && jl_is_long(yl))
return 0;
}
else {
jl_value_t *xl = jl_unwrap_vararg_num(vtx);
if (jl_is_typevar(xl)) {
jl_varbinding_t *xlv = lookup(e, (jl_tvar_t*)xl);
if (xlv)
xl = xlv->lb;
}
if (jl_is_long(xl)) {
if (jl_unbox_long(xl) + 1 == vx) {
// LHS is exhausted. We're a subtype if the RHS is either
// exhausted as well or unbounded (in which case we need to
// set it to 0).
jl_value_t *yl = jl_unwrap_vararg_num(vty);
if (yl) {
if (jl_is_typevar(yl)) {
jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl);
if (ylv)
yl = ylv->lb;
}
if (jl_is_long(yl)) {
return jl_unbox_long(yl) + 1 == vy;
}
Expand Down Expand Up @@ -1090,6 +1090,8 @@ static int subtype_tuple_varargs(
// appropriately.
e->invdepth++;
int ans = subtype((jl_value_t*)jl_any_type, yp1, e, 2);
if (ylv && !ylv->intersected)
ylv->max_offset = max_offsety;
e->invdepth--;
return ans;
}
Expand Down Expand Up @@ -1130,6 +1132,10 @@ static int subtype_tuple_varargs(
e->Loffset = 0;
}
JL_GC_POP();
if (ylv && !ylv->intersected)
ylv->max_offset = max_offsety;
if (xlv && !xlv->intersected)
xlv->max_offset = max_offsetx;
e->invdepth--;
return ans;
}
Expand Down Expand Up @@ -3134,14 +3140,15 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res = NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
save_env(e, &se, 1);
int noinv = !var_occurs_invariant(u->body, u->var);
if (is_leaf_typevar(u->var) && noinv && always_occurs_cov(u->body, u->var, param))
vb.constraintkind = 1;
res = intersect_unionall_(t, u, e, R, param, &vb);
vb.intersected = 1;
if (vb.limited) {
// if the environment got too big, avoid tree recursion and propagate the flag
if (e->vars)
Expand Down Expand Up @@ -3218,17 +3225,20 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t
assert(e->Loffset == 0);
e->Loffset = offset;
jl_varbinding_t *xb = NULL, *yb = NULL;
int8_t max_offsetx = 0, max_offsety = 0;
if (xp2) {
assert(jl_is_typevar(xp2));
xb = lookup(e, (jl_tvar_t*)xp2);
if (xb) xb->intvalued = 1;
if (xb) max_offsetx = xb->max_offset;
if (!yp2)
i2 = bound_var_below((jl_tvar_t*)xp2, xb, e, 0);
}
if (yp2) {
assert(jl_is_typevar(yp2));
yb = lookup(e, (jl_tvar_t*)yp2);
if (yb) yb->intvalued = 1;
if (yb) max_offsety = yb->max_offset;
if (!xp2)
i2 = bound_var_below((jl_tvar_t*)yp2, yb, e, 1);
}
Expand All @@ -3243,14 +3253,27 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t
}
assert(e->Loffset == offset);
e->Loffset = 0;
if (i2 == jl_bottom_type)
if (i2 == jl_bottom_type) {
ii = (jl_value_t*)jl_bottom_type;
else if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2))
ii = (jl_value_t*)vmx;
else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2))
ii = (jl_value_t*)vmy;
else
ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1);
}
else {
if (xb && !xb->intersected) {
xb->max_offset = max_offsetx;
if (offset > xb->max_offset && xb->max_offset >= 0)
xb->max_offset = offset > 32 ? 32 : offset;
}
if (yb && !yb->intersected) {
yb->max_offset = max_offsety;
if (-offset > yb->max_offset && yb->max_offset >= 0)
yb->max_offset = -offset > 32 ? 32 : -offset;
}
if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2))
ii = (jl_value_t*)vmx;
else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2))
ii = (jl_value_t*)vmy;
else
ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1);
}
JL_GC_POP();
return ii;
}
Expand All @@ -3269,6 +3292,24 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten
llx += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(xd, lx-1))) - 1;
if (vvy == JL_VARARG_INT)
lly += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(yd, ly-1))) - 1;
if (vvx == JL_VARARG_BOUND && (vvy == JL_VARARG_BOUND || vvy == JL_VARARG_UNBOUND)) {
jl_value_t *xlen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(xd, lx-1));
assert(xlen && jl_is_typevar(xlen));
jl_varbinding_t *xb = lookup(e, (jl_tvar_t*)xlen);
if (xb && xb->intersected && xb->max_offset > 0) {
assert(xb->max_offset <= 32);
llx += xb->max_offset;
}
}
if (vvy == JL_VARARG_BOUND && (vvx == JL_VARARG_BOUND || vvx == JL_VARARG_UNBOUND)) {
jl_value_t *ylen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(yd, ly-1));
assert(ylen && jl_is_typevar(ylen));
jl_varbinding_t *yb = lookup(e, (jl_tvar_t*)ylen);
if (yb && yb->intersected && yb->max_offset > 0) {
assert(yb->max_offset <= 32);
lly += yb->max_offset;
}
}

if ((vvx == JL_VARARG_NONE || vvx == JL_VARARG_INT) &&
(vvy == JL_VARARG_NONE || vvy == JL_VARARG_INT)) {
Expand Down Expand Up @@ -3301,8 +3342,8 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten
assert(i == j && i == np);
break;
}
if (xi && jl_is_vararg(xi)) vx = vvx != JL_VARARG_INT;
if (yi && jl_is_vararg(yi)) vy = vvy != JL_VARARG_INT;
if (xi && jl_is_vararg(xi)) vx = vvx == JL_VARARG_UNBOUND || (vvx == JL_VARARG_BOUND && i == llx - 1);
if (yi && jl_is_vararg(yi)) vy = vvy == JL_VARARG_UNBOUND || (vvy == JL_VARARG_BOUND && j == lly - 1);
if (xi == NULL || yi == NULL) {
if (vx && intersect_vararg_length(xi, lly+1-llx, e, 0)) {
np = j;
Expand Down Expand Up @@ -3845,15 +3886,15 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
roots = se->roots;
nroots = se->gcframe.nroots >> 2;
}
int n = 0;
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
v = e->vars;
while (v != NULL) {
if (count == 0) {
// need to initialize this
se->buf[n] = 0;
se->buf[n+1] = 0;
se->buf[n+2] = 0;
se->buf[m] = 0;
se->buf[m+1] = 0;
se->buf[m+2] = 0;
se->buf[m+3] = v->max_offset;
}
if (v->occurs) {
// only merge lb/ub/innervars if this var occurs.
Expand All @@ -3879,13 +3920,17 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
roots[n+2] = b2;
}
// record the meeted vars.
se->buf[n] = 1;
se->buf[m] = 1;
}
// always merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > se->buf[n+1])
se->buf[n+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[n+2])
se->buf[n+2] = v->occurs_cov;
if (v->occurs_inv > se->buf[m+1])
se->buf[m+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[m+2])
se->buf[m+2] = v->occurs_cov;
// always merge max_offset by min
if (!v->intersected && v->max_offset < se->buf[m+3])
se->buf[m+3] = v->max_offset;
m = m + 4;
n = n + 3;
v = v->prev;
}
Expand Down Expand Up @@ -3917,7 +3962,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
}
assert(nroots == current_env_length(e) * 3);
assert(nroots % 3 == 0);
for (int n = 0; n < nroots; n = n + 3) {
for (int n = 0, m = 0; n < nroots; n += 3, m += 4) {
if (merged[n] == NULL)
merged[n] = saved[n];
if (merged[n+1] == NULL)
Expand All @@ -3933,7 +3978,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
else
merged[n+2] = b2;
}
me->buf[n] |= se->buf[n];
me->buf[m] |= se->buf[m];
}
}

Expand Down Expand Up @@ -4489,7 +4534,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
Expand Down
23 changes: 14 additions & 9 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2210,13 +2210,19 @@ let A = Tuple{NTuple{N, Int}, NTuple{N, Int}} where N,
Bs = (Tuple{Tuple{Int, Vararg{Any}}, Tuple{Int, Int, Vararg{Any}}},
Tuple{Tuple{Int, Vararg{Any,N1}}, Tuple{Int, Int, Vararg{Any,N2}}} where {N1,N2},
Tuple{Tuple{Int, Vararg{Any,N}} where {N}, Tuple{Int, Int, Vararg{Any,N}} where {N}})
Cerr = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
C = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
C = typeintersect(A, B)
@test C == typeintersect(B, A) != Union{}
@test C != Cerr
# TODO: The ideal result is Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
@test_broken C != Tuple{Tuple{Int, Vararg{Int}}, Tuple{Int, Int, Vararg{Int}}}
@testintersect(A, B, C)
end
A = Tuple{NTuple{N, Int}, Tuple{Int, Vararg{Int, N}}} where N
C = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
A = Tuple{Tuple{Int, Vararg{Int, N}}, NTuple{N, Int}} where N
C = Tuple{Tuple{Int, Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
end

Expand All @@ -2229,9 +2235,8 @@ let A = Pair{NTuple{N, Int}, NTuple{N, Int}} where N,
Bs = (Pair{<:Tuple{Int, Vararg{Int}}, <:Tuple{Int, Int, Vararg{Int}}},
Pair{Tuple{Int, Vararg{Int,N1}}, Tuple{Int, Int, Vararg{Int,N2}}} where {N1,N2},
Pair{<:Tuple{Int, Vararg{Int,N}} where {N}, <:Tuple{Int, Int, Vararg{Int,N}} where {N}})
Cs = (Bs[2], Bs[2], Bs[3])
for (B, C) in zip(Bs, Cs)
# TODO: The ideal result is Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
C = Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
end
Expand Down
Loading