Skip to content

Commit

Permalink
inference: improve TypeVar/Vararg handling
Browse files Browse the repository at this point in the history
During working on the incoming lattice overhaul, I found it's quite
confusing that `TypeVar` and `Vararg` can appear in the same context
as valid `Type` objects as well as extended lattice elements.
Since it usually needs special cases to operate on `TypeVar` and `Vararg`
(e.g. they can not be used in subtyping as an obvious example), I believe
it would be great avoid bugs and catch logic errors in the future development
if we separate contexts where they can appear from ones where `Type`
objects and extended lattice elements are expected.

So this commit:
- tries to separate their context, e.g. now `TypeVar` and `Vararg` should
  not be used in `_limit_type_size`, which is supposed to return `Type`,
  but they should be handled its helper function `__limit_type_size`
- makes sure `tfunc`s don't return `TypeVar`s and `TypeVar` never spills
  into the abstract state
- makes sure `widenconst` are not called on `TypeVar` and `Vararg`,
  and now `widenconst` is ensured to return `Type` always
- and does other general refactors
  • Loading branch information
aviatesk committed Oct 16, 2021
1 parent b8ed1ae commit d60f92c
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 89 deletions.
20 changes: 8 additions & 12 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -778,10 +778,7 @@ function precise_container_type(interp::AbstractInterpreter, @nospecialize(itft)
if isa(tti, DataType) && tti.name === NamedTuple_typename
# A NamedTuple iteration is the same as the iteration of its Tuple parameter:
# compute a new `tti == unwrap_unionall(tti0)` based on that Tuple type
tti = tti.parameters[2]
while isa(tti, TypeVar)
tti = tti.ub
end
tti = unwraptv(tti.parameters[2])
tti0 = rewrap_unionall(tti, tti0)
end
if isa(tti, Union)
Expand Down Expand Up @@ -1153,7 +1150,8 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
end
end
end
return isa(rt, TypeVar) ? rt.ub : rt
@assert !isa(rt, TypeVar) "unhandled TypeVar"
return rt
end

function abstract_call_unionall(argtypes::Vector{Any})
Expand Down Expand Up @@ -1405,7 +1403,7 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
spsig = linfo.def.sig
if isa(spsig, UnionAll)
if !isempty(linfo.sparam_vals)
sparam_vals = Any[isa(v, Core.TypeofVararg) ? TypeVar(:N, Union{}, Any) :
sparam_vals = Any[isvarargtype(v) ? TypeVar(:N, Union{}, Any) :
v for v in linfo.sparam_vals]
T = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), T, spsig, sparam_vals)
isref && isreturn && T === Any && return Bottom # catch invalid return Ref{T} where T = Any
Expand All @@ -1419,10 +1417,7 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
end
end
end
while isa(T, TypeVar)
T = T.ub
end
return T
return unwraptv(T)
end

function abstract_eval_cfunction(interp::AbstractInterpreter, e::Expr, vtypes::VarTable, sv::InferenceState)
Expand Down Expand Up @@ -1632,7 +1627,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
else
t = abstract_eval_value_expr(interp, e, vtypes, sv)
end
@assert !isa(t, TypeVar)
@assert !isa(t, TypeVar) "unhandled TypeVar"
if isa(t, DataType) && isdefined(t, :instance)
# replace singleton types with their equivalent Const object
t = Const(t.instance)
Expand Down Expand Up @@ -1717,7 +1712,8 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
fields = copy(rt.fields)
haveconst = false
for i in 1:length(fields)
a = widenreturn(fields[i], bestguess, nslots, slottypes, changes)
a = fields[i]
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
if !haveconst && has_const_info(a)
# TODO: consider adding && const_prop_profitable(a) here?
haveconst = true
Expand Down
4 changes: 1 addition & 3 deletions base/compiler/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ let
# remove any TypeVars from the intersection
typ = Any[m.spec_types.parameters...]
for i = 1:length(typ)
if isa(typ[i], TypeVar)
typ[i] = typ[i].ub
end
typ[i] = unwraptv(typ[i])
end
typeinf_type(interp, m.method, Tuple{typ...}, m.sparams)
end
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using Core.Intrinsics, Core.IR

import Core: print, println, show, write, unsafe_write, stdout, stderr,
_apply_iterate, svec, apply_type, Builtin, IntrinsicFunction,
MethodInstance, CodeInstance, MethodMatch, PartialOpaque
MethodInstance, CodeInstance, MethodMatch, PartialOpaque,
TypeofVararg

const getproperty = Core.getfield
const setproperty! = Core.setfield!
Expand Down
4 changes: 1 addition & 3 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
atyp = unwrapva(atyp)
tail_index -= 1
end
while isa(atyp, TypeVar)
atyp = atyp.ub
end
atyp = unwraptv(atyp)
if isa(atyp, DataType) && isdefined(atyp, :instance)
# replace singleton types with their equivalent Const object
atyp = Const(atyp.instance)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
ty = UnionAll(tv, Type{tv})
end
end
elseif isa(v, Core.TypeofVararg)
elseif isvarargtype(v)
ty = Int
else
ty = Const(v)
Expand Down
6 changes: 2 additions & 4 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ end

function validate_sparams(sparams::SimpleVector)
for i = 1:length(sparams)
(isa(sparams[i], TypeVar) || isa(sparams[i], Core.TypeofVararg)) && return false
(isa(sparams[i], TypeVar) || isvarargtype(sparams[i])) && return false
end
return true
end
Expand Down Expand Up @@ -873,9 +873,7 @@ function is_valid_type_for_apply_rewrite(@nospecialize(typ), params::Optimizatio
typ = widenconst(typ)
if isa(typ, DataType) && typ.name === NamedTuple_typename
typ = typ.parameters[2]
while isa(typ, TypeVar)
typ = typ.ub
end
typ = unwraptv(typ)
end
isa(typ, DataType) || return false
if typ.name === Tuple.name
Expand Down
43 changes: 25 additions & 18 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -549,11 +549,9 @@ function typeof_tfunc(@nospecialize(t))
return Type{<:t}
end
elseif isa(t, Union)
a = widenconst(typeof_tfunc(t.a))
b = widenconst(typeof_tfunc(t.b))
a = widenconst(_typeof_tfunc(t.a))
b = widenconst(_typeof_tfunc(t.b))
return Union{a, b}
elseif isa(t, TypeVar) && !(Any === t.ub)
return typeof_tfunc(t.ub)
elseif isa(t, UnionAll)
u = unwrap_unionall(t)
if isa(u, DataType) && !isabstracttype(u)
Expand All @@ -570,6 +568,13 @@ function typeof_tfunc(@nospecialize(t))
end
return DataType # typeof(anything)::DataType
end
# helper function of `typeof_tfunc`, which accepts `TypeVar`
function _typeof_tfunc(@nospecialize(t))
if isa(t, TypeVar)
return t.ub !== Any ? _typeof_tfunc(t.ub) : DataType
end
return typeof_tfunc(t)
end
add_tfunc(typeof, 1, 1, typeof_tfunc, 1)

function typeassert_tfunc(@nospecialize(v), @nospecialize(t))
Expand Down Expand Up @@ -865,10 +870,7 @@ function getfield_tfunc(@nospecialize(s00), @nospecialize(name))
elseif Symbol name
name = Int
end
_ts = s.parameters[2]
while isa(_ts, TypeVar)
_ts = _ts.ub
end
_ts = unwraptv(s.parameters[2])
_ts = rewrap_unionall(_ts, s00)
if !(_ts <: Tuple)
return Any
Expand Down Expand Up @@ -1268,7 +1270,7 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
return Any
end
if !isempty(args) && isvarargtype(args[end])
return isvarargtype(headtype) ? Core.TypeofVararg : Type
return isvarargtype(headtype) ? TypeofVararg : Type
end
largs = length(args)
if headtype === Union
Expand Down Expand Up @@ -1329,7 +1331,7 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
canconst &= !has_free_typevars(aip1)
push!(tparams, aip1)
elseif isa(ai, Const) && (isa(ai.val, Type) || isa(ai.val, TypeVar) ||
valid_tparam(ai.val) || (istuple && isa(ai.val, Core.TypeofVararg)))
valid_tparam(ai.val) || (istuple && isvarargtype(ai.val)))
push!(tparams, ai.val)
elseif isa(ai, PartialTypeVar)
canconst = false
Expand Down Expand Up @@ -1395,11 +1397,11 @@ function apply_type_tfunc(@nospecialize(headtypetype), @nospecialize args...)
catch ex
# type instantiation might fail if one of the type parameters
# doesn't match, which could happen if a type estimate is too coarse
return isvarargtype(headtype) ? Core.TypeofVararg : Type{<:headtype}
return isvarargtype(headtype) ? TypeofVararg : Type{<:headtype}
end
!uncertain && canconst && return Const(appl)
if isvarargtype(appl)
return Core.TypeofVararg
return TypeofVararg
end
if istuple
return Type{<:appl}
Expand Down Expand Up @@ -1439,12 +1441,15 @@ function tuple_tfunc(atypes::Vector{Any})
if has_struct_const_info(x)
anyinfo = true
else
atypes[i] = x = widenconst(x)
if !isvarargtype(x)
x = widenconst(x)
end
atypes[i] = x
end
if isa(x, Const)
params[i] = typeof(x.val)
else
x = widenconst(x)
x = isvarargtype(x) ? x : widenconst(x)
if isType(x)
anyinfo = true
xparam = x.parameters[1]
Expand All @@ -1467,10 +1472,12 @@ end
function arrayref_tfunc(@nospecialize(boundscheck), @nospecialize(a), @nospecialize i...)
a = widenconst(a)
if a <: Array
if isa(a, DataType) && (isa(a.parameters[1], Type) || isa(a.parameters[1], TypeVar))
if isa(a, DataType) && begin
ap1 = a.parameters[1]
isa(ap1, Type) || isa(ap1, TypeVar)
end
# TODO: the TypeVar case should not be needed here
a = a.parameters[1]
return isa(a, TypeVar) ? a.ub : a
return unwraptv(ap1)
elseif isa(a, UnionAll) && !has_free_typevars(a)
unw = unwrap_unionall(a)
if isa(unw, DataType)
Expand Down Expand Up @@ -1632,7 +1639,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
if length(argtypes) - 1 == tf[2]
argtypes = argtypes[1:end-1]
else
vatype = argtypes[end]::Core.TypeofVararg
vatype = argtypes[end]::TypeofVararg
argtypes = argtypes[1:end-1]
while length(argtypes) < tf[1]
push!(argtypes, unwrapva(vatype))
Expand Down
6 changes: 1 addition & 5 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -971,11 +971,7 @@ function _return_type(interp::AbstractInterpreter, @nospecialize(f), @nospeciali
rt = Union{}
if isa(f, Builtin)
rt = builtin_tfunction(interp, f, Any[t.parameters...], nothing)
if isa(rt, TypeVar)
rt = rt.ub
else
rt = widenconst(rt)
end
rt = widenconst(rt)
else
for match in _methods(f, t, -1, get_world_counter(interp))::Vector
match = match::Core.MethodMatch
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ widenconst(c::PartialTypeVar) = TypeVar
widenconst(t::PartialStruct) = t.typ
widenconst(t::PartialOpaque) = t.typ
widenconst(t::Type) = t
widenconst(t::TypeVar) = t
widenconst(t::Core.TypeofVararg) = t
widenconst(t::TypeVar) = error("unhandled TypeVar")
widenconst(t::TypeofVararg) = error("unhandled Vararg")
widenconst(t::LimitedAccuracy) = error("unhandled LimitedAccuracy")

issubstate(a::VarState, b::VarState) = (a.typ b.typ && a.undef <= b.undef)
Expand Down
75 changes: 42 additions & 33 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int)
# see if it is derived from the body
# also handle the var here, since this construct bounds the mindepth to the smallest possible value
return is_derived_type(t, c.var.ub, mindepth) || is_derived_type(t, c.body, mindepth)
elseif isa(c, Core.TypeofVararg)
elseif isvarargtype(c)
return is_derived_type(t, unwrapva(c), mindepth)
elseif isa(c, DataType)
if mindepth > 0
Expand Down Expand Up @@ -79,6 +79,7 @@ end
# The goal of this function is to return a type of greater "size" and less "complexity" than
# both `t` or `c` over the lattice defined by `sources`, `depth`, and `allowed_tuplelen`.
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
@assert isa(t, Type) && isa(c, Type) "unhandled TypeVar / Vararg"
if t === c
return t # quick egal test
elseif t === Union{}
Expand All @@ -98,40 +99,22 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
# first attempt to turn `c` into a type that contributes meaningful information
# by peeling off meaningless non-matching wrappers of comparison one at a time
# then unwrap `t`
if isa(c, TypeVar)
if isa(t, TypeVar) && t.ub === c.ub && (t.lb === Union{} || t.lb === c.lb)
return t # it's ok to change the name, or widen `lb` to Union{}, so we can handle this immediately here
end
return _limit_type_size(t, c.ub, sources, depth, allowed_tuplelen)
end
# NOTE that `TypeVar` / `Vararg` are handled separately to catch the logic errors
if isa(c, UnionAll)
return _limit_type_size(t, c.body, sources, depth, allowed_tuplelen)
return __limit_type_size(t, c.body, sources, depth, allowed_tuplelen)::Type
end
if isa(t, UnionAll)
tbody = _limit_type_size(t.body, c, sources, depth, allowed_tuplelen)
tbody = __limit_type_size(t.body, c, sources, depth, allowed_tuplelen)
tbody === t.body && return t
return UnionAll(t.var, tbody)
elseif isa(t, TypeVar)
# don't have a matching TypeVar in comparison, so we keep just the upper bound
return _limit_type_size(t.ub, c, sources, depth, allowed_tuplelen)
return UnionAll(t.var, tbody)::Type
elseif isa(t, Union)
if isa(c, Union)
a = _limit_type_size(t.a, c.a, sources, depth, allowed_tuplelen)
b = _limit_type_size(t.b, c.b, sources, depth, allowed_tuplelen)
a = __limit_type_size(t.a, c.a, sources, depth, allowed_tuplelen)
b = __limit_type_size(t.b, c.b, sources, depth, allowed_tuplelen)
return Union{a, b}
end
elseif isa(t, Core.TypeofVararg)
isa(c, Core.TypeofVararg) || return Vararg
VaT = _limit_type_size(unwrapva(t), unwrapva(c), sources, depth + 1, 0)
if isdefined(t, :N) && (isa(t.N, TypeVar) || (isdefined(c, :N) && t.N === c.N))
return Vararg{VaT, t.N}
end
return Vararg{VaT}
elseif isa(t, DataType)
if isa(c, Core.TypeofVararg)
# Tuple{Vararg{T}} --> Tuple{T} is OK
return _limit_type_size(t, unwrapva(c), sources, depth, 0)
elseif isType(t) # allow taking typeof as Type{...}, but ensure it doesn't start nesting
if isType(t) # allow taking typeof as Type{...}, but ensure it doesn't start nesting
tt = unwrap_unionall(t.parameters[1])
(!isa(tt, DataType) || isType(tt)) && (depth += 1)
is_derived_type_from_any(tt, sources, depth) && return t
Expand Down Expand Up @@ -161,7 +144,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
else
cPi = Any
end
Q[i] = _limit_type_size(Q[i], cPi, sources, depth + 1, 0)
Q[i] = __limit_type_size(Q[i], cPi, sources, depth + 1, 0)
end
return Tuple{Q...}
end
Expand All @@ -182,6 +165,31 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec
return Any
end

# helper function of `_limit_type_size`, which has the right to take and return `TypeVar` / `Vararg`
function __limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int)
if isa(c, TypeVar)
if isa(t, TypeVar) && t.ub === c.ub && (t.lb === Union{} || t.lb === c.lb)
return t # it's ok to change the name, or widen `lb` to Union{}, so we can handle this immediately here
end
return __limit_type_size(t, c.ub, sources, depth, allowed_tuplelen)
elseif isa(t, TypeVar)
# don't have a matching TypeVar in comparison, so we keep just the upper bound
return __limit_type_size(t.ub, c, sources, depth, allowed_tuplelen)
elseif isvarargtype(t)
isvarargtype(c) || return Vararg
VaT = __limit_type_size(unwrapva(t), unwrapva(c), sources, depth + 1, 0)
if isdefined(t, :N) && (isa(t.N, TypeVar) || (isdefined(c, :N) && t.N === c.N))
return Vararg{VaT, t.N}
end
return Vararg{VaT}
elseif isvarargtype(c)
# Tuple{Vararg{T}} --> Tuple{T} is OK
return __limit_type_size(t, unwrapva(c), sources, depth, 0)
else
return _limit_type_size(t, c, sources, depth, allowed_tuplelen)
end
end

function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, tupledepth::Int, allowed_tuplelen::Int)
# detect cases where the comparison is trivial
if t === c
Expand Down Expand Up @@ -225,13 +233,13 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe
return t !== 1 && !(0 <= t < c) # alternatively, could use !(abs(t) <= abs(c) || abs(t) < n) for some n
end
# base case for data types
if isa(t, Core.TypeofVararg)
if isa(c, Core.TypeofVararg)
if isvarargtype(t)
if isvarargtype(c)
return type_more_complex(unwrapva(t), unwrapva(c), sources, depth + 1, tupledepth, 0)
end
elseif isa(t, DataType)
tP = t.parameters
if isa(c, Core.TypeofVararg)
if isvarargtype(c)
return type_more_complex(t, unwrapva(c), sources, depth, tupledepth, 0)
elseif isType(t) # allow taking typeof any source type anywhere as Type{...}, as long as it isn't nesting Type{Type{...}}
tt = unwrap_unionall(t.parameters[1])
Expand Down Expand Up @@ -603,10 +611,11 @@ function tmeet(@nospecialize(v), @nospecialize(t))
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
if isa(v.fields[i], Core.TypeofVararg)
new_fields[i] = v.fields[i]
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(v.fields[i], widenconst(getfield_tfunc(t, Const(i))))
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
Expand Down
Loading

0 comments on commit d60f92c

Please sign in to comment.