From a80fc029ff1fada2db597fbb93387698f216a218 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Fri, 2 Sep 2022 16:16:00 +0800 Subject: [PATCH] Expend more `Vararg` elements during re-intersection if valid. This should be valid if the type var is used only for `Vararg` length. This commit add a new field `max_offset` to our type env. which is used to tracks the minimum length of a `Vararg` during the 1st round intersection. (It's value would be set to `-1` if this var has other usage.) If we got a positive value, then we can expand more elements during the 2nd round intersection safely. With these extra elements, the offset between 2 `Vararg`s will reduce to 0 (hopefully), and thus we get a more accurate result. --- src/subtype.c | 147 +++++++++++++++++++++++++++++++----------------- test/subtype.jl | 23 +++++--- 2 files changed, 110 insertions(+), 60 deletions(-) diff --git a/src/subtype.c b/src/subtype.c index ce2572be2952d..17a6bf7041e3b 100644 --- a/src/subtype.c +++ b/src/subtype.c @@ -69,6 +69,8 @@ typedef struct jl_varbinding_t { int8_t occurs_inv; // occurs in invariant position int8_t occurs_cov; // # of occurrences in covariant position int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete + int8_t max_offset; // record the maximum positive offset of the variable (up to 32) + // max_offset < 0 if this variable occurs outside VarargNum. // constraintkind: in covariant position, we try three different ways to compute var ∩ type: // let ub = var.ub ∩ type // 0 - var.ub <: type ? var : ub @@ -77,6 +79,7 @@ typedef struct jl_varbinding_t { int8_t constraintkind; int8_t intvalued; // intvalued: must be integer-valued; i.e. occurs as N in Vararg{_,N} int8_t limited; + int8_t intersected; // whether this variable has been intersected int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var // array of typevars that our bounds depend on, whose UnionAlls need to be // moved outside ours. @@ -168,9 +171,9 @@ static int current_env_length(jl_stenv_t *e) typedef struct { int8_t *buf; int rdepth; - int8_t _space[24]; // == 8 * 3 + int8_t _space[32]; // == 8 * 4 jl_gcframe_t gcframe; - jl_value_t *roots[24]; + jl_value_t *roots[24]; // == 8 * 3 } jl_savedenv_t; static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root) @@ -200,6 +203,7 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root) se->buf[j++] = v->occurs; se->buf[j++] = v->occurs_inv; se->buf[j++] = v->occurs_cov; + se->buf[j++] = v->max_offset; v = v->prev; } assert(i == nroots); (void)nroots; @@ -231,7 +235,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root) ct->gcstack = &se->gcframe; } } - se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space); + se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space); #ifdef __clang_gcanalyzer__ memset(se->buf, 0, len * 3); #endif @@ -281,6 +285,7 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO v->occurs = se->buf[j++]; v->occurs_inv = se->buf[j++]; v->occurs_cov = se->buf[j++]; + v->max_offset = se->buf[j++]; v = v->prev; } assert(i == nroots); (void)nroots; @@ -677,6 +682,10 @@ static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) else if (vb->occurs_cov < 2) { vb->occurs_cov++; } + // Always set `max_offset` to `-1` during the 1st round intersection. + // Would be recovered in `intersect_varargs`/`subtype_tuple_varargs` if needed. + if (!vb->intersected) + vb->max_offset = -1; } } @@ -888,7 +897,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e) static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param) { u = unalias_unionall(u, e); - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0, e->invdepth, NULL, e->vars }; JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars); e->vars = &vb; @@ -1008,39 +1017,30 @@ static int subtype_tuple_varargs( jl_value_t *xp0 = jl_unwrap_vararg(vtx); jl_value_t *xp1 = jl_unwrap_vararg_num(vtx); jl_value_t *yp0 = jl_unwrap_vararg(vty); jl_value_t *yp1 = jl_unwrap_vararg_num(vty); + jl_varbinding_t *xlv = NULL, *ylv = NULL; + if (xp1 && jl_is_typevar(xp1)) + xlv = lookup(e, (jl_tvar_t*)xp1); + if (yp1 && jl_is_typevar(yp1)) + ylv = lookup(e, (jl_tvar_t*)yp1); + + int8_t max_offsetx = xlv ? xlv->max_offset : 0; + int8_t max_offsety = ylv ? ylv->max_offset : 0; + + jl_value_t *xl = xlv ? xlv->lb : xp1; + jl_value_t *yl = ylv ? ylv->lb : yp1; + if (!xp1) { - jl_value_t *yl = yp1; - if (yl) { - // Unconstrained on the left, constrained on the right - if (jl_is_typevar(yl)) { - jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl); - if (ylv) - yl = ylv->lb; - } - if (jl_is_long(yl)) { - return 0; - } - } + // Unconstrained on the left, constrained on the right + if (yl && jl_is_long(yl)) + return 0; } else { - jl_value_t *xl = jl_unwrap_vararg_num(vtx); - if (jl_is_typevar(xl)) { - jl_varbinding_t *xlv = lookup(e, (jl_tvar_t*)xl); - if (xlv) - xl = xlv->lb; - } if (jl_is_long(xl)) { if (jl_unbox_long(xl) + 1 == vx) { // LHS is exhausted. We're a subtype if the RHS is either // exhausted as well or unbounded (in which case we need to // set it to 0). - jl_value_t *yl = jl_unwrap_vararg_num(vty); if (yl) { - if (jl_is_typevar(yl)) { - jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl); - if (ylv) - yl = ylv->lb; - } if (jl_is_long(yl)) { return jl_unbox_long(yl) + 1 == vy; } @@ -1090,6 +1090,8 @@ static int subtype_tuple_varargs( // appropriately. e->invdepth++; int ans = subtype((jl_value_t*)jl_any_type, yp1, e, 2); + if (ylv && !ylv->intersected) + ylv->max_offset = max_offsety; e->invdepth--; return ans; } @@ -1130,6 +1132,10 @@ static int subtype_tuple_varargs( e->Loffset = 0; } JL_GC_POP(); + if (ylv && !ylv->intersected) + ylv->max_offset = max_offsety; + if (xlv && !xlv->intersected) + xlv->max_offset = max_offsetx; e->invdepth--; return ans; } @@ -3134,7 +3140,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ { jl_value_t *res = NULL; jl_savedenv_t se; - jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, + jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0, e->invdepth, NULL, e->vars }; JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars); save_env(e, &se, 1); @@ -3142,6 +3148,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_ if (is_leaf_typevar(u->var) && noinv && always_occurs_cov(u->body, u->var, param)) vb.constraintkind = 1; res = intersect_unionall_(t, u, e, R, param, &vb); + vb.intersected = 1; if (vb.limited) { // if the environment got too big, avoid tree recursion and propagate the flag if (e->vars) @@ -3218,10 +3225,12 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t assert(e->Loffset == 0); e->Loffset = offset; jl_varbinding_t *xb = NULL, *yb = NULL; + int8_t max_offsetx = 0, max_offsety = 0; if (xp2) { assert(jl_is_typevar(xp2)); xb = lookup(e, (jl_tvar_t*)xp2); if (xb) xb->intvalued = 1; + if (xb) max_offsetx = xb->max_offset; if (!yp2) i2 = bound_var_below((jl_tvar_t*)xp2, xb, e, 0); } @@ -3229,6 +3238,7 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t assert(jl_is_typevar(yp2)); yb = lookup(e, (jl_tvar_t*)yp2); if (yb) yb->intvalued = 1; + if (yb) max_offsety = yb->max_offset; if (!xp2) i2 = bound_var_below((jl_tvar_t*)yp2, yb, e, 1); } @@ -3243,14 +3253,27 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t } assert(e->Loffset == offset); e->Loffset = 0; - if (i2 == jl_bottom_type) + if (i2 == jl_bottom_type) { ii = (jl_value_t*)jl_bottom_type; - else if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2)) - ii = (jl_value_t*)vmx; - else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2)) - ii = (jl_value_t*)vmy; - else - ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1); + } + else { + if (xb && !xb->intersected) { + xb->max_offset = max_offsetx; + if (offset > xb->max_offset && xb->max_offset >= 0) + xb->max_offset = offset > 32 ? 32 : offset; + } + if (yb && !yb->intersected) { + yb->max_offset = max_offsety; + if (-offset > yb->max_offset && yb->max_offset >= 0) + yb->max_offset = -offset > 32 ? 32 : -offset; + } + if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2)) + ii = (jl_value_t*)vmx; + else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2)) + ii = (jl_value_t*)vmy; + else + ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1); + } JL_GC_POP(); return ii; } @@ -3269,6 +3292,24 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten llx += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(xd, lx-1))) - 1; if (vvy == JL_VARARG_INT) lly += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(yd, ly-1))) - 1; + if (vvx == JL_VARARG_BOUND && (vvy == JL_VARARG_BOUND || vvy == JL_VARARG_UNBOUND)) { + jl_value_t *xlen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(xd, lx-1)); + assert(xlen && jl_is_typevar(xlen)); + jl_varbinding_t *xb = lookup(e, (jl_tvar_t*)xlen); + if (xb && xb->intersected && xb->max_offset > 0) { + assert(xb->max_offset <= 32); + llx += xb->max_offset; + } + } + if (vvy == JL_VARARG_BOUND && (vvx == JL_VARARG_BOUND || vvx == JL_VARARG_UNBOUND)) { + jl_value_t *ylen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(yd, ly-1)); + assert(ylen && jl_is_typevar(ylen)); + jl_varbinding_t *yb = lookup(e, (jl_tvar_t*)ylen); + if (yb && yb->intersected && yb->max_offset > 0) { + assert(yb->max_offset <= 32); + lly += yb->max_offset; + } + } if ((vvx == JL_VARARG_NONE || vvx == JL_VARARG_INT) && (vvy == JL_VARARG_NONE || vvy == JL_VARARG_INT)) { @@ -3301,8 +3342,8 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten assert(i == j && i == np); break; } - if (xi && jl_is_vararg(xi)) vx = vvx != JL_VARARG_INT; - if (yi && jl_is_vararg(yi)) vy = vvy != JL_VARARG_INT; + if (xi && jl_is_vararg(xi)) vx = vvx == JL_VARARG_UNBOUND || (vvx == JL_VARARG_BOUND && i == llx - 1); + if (yi && jl_is_vararg(yi)) vy = vvy == JL_VARARG_UNBOUND || (vvy == JL_VARARG_BOUND && j == lly - 1); if (xi == NULL || yi == NULL) { if (vx && intersect_vararg_length(xi, lly+1-llx, e, 0)) { np = j; @@ -3845,15 +3886,15 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count) roots = se->roots; nroots = se->gcframe.nroots >> 2; } - int n = 0; + int m = 0, n = 0; jl_varbinding_t *v = e->vars; - v = e->vars; while (v != NULL) { if (count == 0) { // need to initialize this - se->buf[n] = 0; - se->buf[n+1] = 0; - se->buf[n+2] = 0; + se->buf[m] = 0; + se->buf[m+1] = 0; + se->buf[m+2] = 0; + se->buf[m+3] = v->max_offset; } if (v->occurs) { // only merge lb/ub/innervars if this var occurs. @@ -3879,13 +3920,17 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count) roots[n+2] = b2; } // record the meeted vars. - se->buf[n] = 1; + se->buf[m] = 1; } // always merge occurs_inv/cov by max (never decrease) - if (v->occurs_inv > se->buf[n+1]) - se->buf[n+1] = v->occurs_inv; - if (v->occurs_cov > se->buf[n+2]) - se->buf[n+2] = v->occurs_cov; + if (v->occurs_inv > se->buf[m+1]) + se->buf[m+1] = v->occurs_inv; + if (v->occurs_cov > se->buf[m+2]) + se->buf[m+2] = v->occurs_cov; + // always merge max_offset by min + if (!v->intersected && v->max_offset < se->buf[m+3]) + se->buf[m+3] = v->max_offset; + m = m + 4; n = n + 3; v = v->prev; } @@ -3917,7 +3962,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se) } assert(nroots == current_env_length(e) * 3); assert(nroots % 3 == 0); - for (int n = 0; n < nroots; n = n + 3) { + for (int n = 0, m = 0; n < nroots; n += 3, m += 4) { if (merged[n] == NULL) merged[n] = saved[n]; if (merged[n+1] == NULL) @@ -3933,7 +3978,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se) else merged[n+2] = b2; } - me->buf[n] |= se->buf[n]; + me->buf[m] |= se->buf[m]; } } @@ -4489,7 +4534,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) { static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot) { - jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; + jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot }; jl_value_t *nt; JL_GC_PUSH2(&vb.innervars, &nt); if (jl_is_unionall(u->body)) diff --git a/test/subtype.jl b/test/subtype.jl index 9a1dd62f4f9c4..edc38c8556f3c 100644 --- a/test/subtype.jl +++ b/test/subtype.jl @@ -2210,13 +2210,19 @@ let A = Tuple{NTuple{N, Int}, NTuple{N, Int}} where N, Bs = (Tuple{Tuple{Int, Vararg{Any}}, Tuple{Int, Int, Vararg{Any}}}, Tuple{Tuple{Int, Vararg{Any,N1}}, Tuple{Int, Int, Vararg{Any,N2}}} where {N1,N2}, Tuple{Tuple{Int, Vararg{Any,N}} where {N}, Tuple{Int, Int, Vararg{Any,N}} where {N}}) - Cerr = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} + C = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} for B in Bs - C = typeintersect(A, B) - @test C == typeintersect(B, A) != Union{} - @test C != Cerr - # TODO: The ideal result is Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} - @test_broken C != Tuple{Tuple{Int, Vararg{Int}}, Tuple{Int, Int, Vararg{Int}}} + @testintersect(A, B, C) + end + A = Tuple{NTuple{N, Int}, Tuple{Int, Vararg{Int, N}}} where N + C = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} + for B in Bs + @testintersect(A, B, C) + end + A = Tuple{Tuple{Int, Vararg{Int, N}}, NTuple{N, Int}} where N + C = Tuple{Tuple{Int, Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} + for B in Bs + @testintersect(A, B, C) end end @@ -2229,9 +2235,8 @@ let A = Pair{NTuple{N, Int}, NTuple{N, Int}} where N, Bs = (Pair{<:Tuple{Int, Vararg{Int}}, <:Tuple{Int, Int, Vararg{Int}}}, Pair{Tuple{Int, Vararg{Int,N1}}, Tuple{Int, Int, Vararg{Int,N2}}} where {N1,N2}, Pair{<:Tuple{Int, Vararg{Int,N}} where {N}, <:Tuple{Int, Int, Vararg{Int,N}} where {N}}) - Cs = (Bs[2], Bs[2], Bs[3]) - for (B, C) in zip(Bs, Cs) - # TODO: The ideal result is Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} + C = Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N} + for B in Bs @testintersect(A, B, C) end end