Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: mega lattice implementation overhaul #42596

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -421,10 +421,7 @@ eval(Core, :(CodeInstance(mi::MethodInstance, @nospecialize(rettype), @nospecial
min_world::UInt, max_world::UInt) =
ccall(:jl_new_codeinst, Ref{CodeInstance}, (Any, Any, Any, Any, Int32, UInt, UInt),
mi, rettype, inferred_const, inferred, const_flags, min_world, max_world)))
eval(Core, :(Const(@nospecialize(v)) = $(Expr(:new, :Const, :v))))
eval(Core, :(PartialStruct(@nospecialize(typ), fields::Array{Any, 1}) = $(Expr(:new, :PartialStruct, :typ, :fields))))
eval(Core, :(PartialOpaque(@nospecialize(typ), @nospecialize(env), isva::Bool, parent::MethodInstance, source::Method) = $(Expr(:new, :PartialOpaque, :typ, :env, :isva, :parent, :source))))
eval(Core, :(InterConditional(slot::Int, @nospecialize(vtype), @nospecialize(elsetype)) = $(Expr(:new, :InterConditional, :slot, :vtype, :elsetype))))
eval(Core, :(MethodMatch(@nospecialize(spec_types), sparams::SimpleVector, method::Method, fully_covers::Bool) =
$(Expr(:new, :MethodMatch, :spec_types, :sparams, :method, :fully_covers))))

Expand Down Expand Up @@ -495,13 +492,11 @@ Symbol(s::Symbol) = s
module IR
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode

import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode

end

Expand Down
809 changes: 417 additions & 392 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

31 changes: 17 additions & 14 deletions base/compiler/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ include("ordering.jl")
using .Order
include("sort.jl")
using .Sort
# required by sort/sort! functions
function extrema(x::Array)
isempty(x) && throw(ArgumentError("collection must be non-empty"))
vmin = vmax = x[1]
for i in 2:length(x)
xi = x[i]
vmax = max(vmax, xi)
vmin = min(vmin, xi)
end
return vmin, vmax
end

# We don't include some.jl, but this definition is still useful.
something(x::Nothing, y...) = something(y...)
Expand All @@ -114,6 +125,12 @@ something(x::Any, y...) = x
# compiler #
############

include("compiler/typelattice.jl")

const Argtypes = Vector{LatticeElement}
const EMPTY_SLOTTYPES = Argtypes()
anymap(f::Function, a::Argtypes) = Any[ f(a[i]) for i in 1:length(a) ]

include("compiler/cicache.jl")
include("compiler/types.jl")
include("compiler/utilities.jl")
Expand All @@ -125,27 +142,13 @@ include("compiler/inferencestate.jl")

include("compiler/typeutils.jl")
include("compiler/typelimits.jl")
include("compiler/typelattice.jl")
include("compiler/tfuncs.jl")
include("compiler/stmtinfo.jl")

include("compiler/abstractinterpretation.jl")
include("compiler/typeinfer.jl")
include("compiler/optimize.jl") # TODO: break this up further + extract utilities

# required for bootstrap
# TODO: find why this is needed and remove it.
function extrema(x::Array)
isempty(x) && throw(ArgumentError("collection must be non-empty"))
vmin = vmax = x[1]
for i in 2:length(x)
xi = x[i]
vmax = max(vmax, xi)
vmin = min(vmin, xi)
end
return vmin, vmax
end

include("compiler/bootstrap.jl")
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)

Expand Down
49 changes: 25 additions & 24 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

function is_argtype_match(@nospecialize(given_argtype),
@nospecialize(cache_argtype),
function is_argtype_match(given_argtype::LatticeElement,
cache_argtype::LatticeElement,
overridden_by_const::Bool)
if is_forwardable_argtype(given_argtype)
return is_lattice_equal(given_argtype, cache_argtype)
Expand All @@ -10,10 +10,10 @@ function is_argtype_match(@nospecialize(given_argtype),
end

function is_forwardable_argtype(@nospecialize x)
return isa(x, Const) ||
isa(x, Conditional) ||
isa(x, PartialStruct) ||
isa(x, PartialOpaque)
return isConst(x) ||
isConditional(x) ||
isPartialStruct(x) ||
isPartialOpaque(x)
end

# In theory, there could be a `cache` containing a matching `InferenceResult`
Expand All @@ -26,43 +26,43 @@ function matching_cache_argtypes(
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
given_argtypes = Vector{Any}(undef, length(argtypes))
given_argtypes = Vector{LatticeElement}(undef, length(argtypes))
local condargs = nothing
for i in 1:length(argtypes)
argtype = argtypes[i]
# forward `Conditional` if it conveys a constraint on any other argument
if isa(argtype, Conditional) && fargs !== nothing
cnd = argtype
if isConditional(argtype) && fargs !== nothing
cnd = conditional(argtype)
slotid = find_constrained_arg(cnd, fargs, sv)
if slotid !== nothing
# using union-split signature, we may be able to narrow down `Conditional`
sigt = widenconst(slotid > nargs ? argtypes[slotid] : cache_argtypes[slotid])
vtype = tmeet(cnd.vtype, sigt)
elsetype = tmeet(cnd.elsetype, sigt)
vtype = cnd.vtypesigt
elsetype = cnd.elsetypesigt
if vtype === Bottom && elsetype === Bottom
# we accidentally proved this method match is impossible
# TODO bail out here immediately rather than just propagating Bottom ?
given_argtypes[i] = Bottom
given_argtypes[i] =
else
if condargs === nothing
condargs = Tuple{Int,Int}[]
end
push!(condargs, (slotid, i))
given_argtypes[i] = Conditional(SlotNumber(slotid), vtype, elsetype)
given_argtypes[i] = Conditional(slotid, vtype, elsetype)
end
continue
end
end
given_argtypes[i] = widenconditional(argtype)
end
isva = va_override || linfo.def.isva
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
if isva || isVararg(given_argtypes[end])
isva_given_argtypes = Vector{LatticeElement}(undef, nargs)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
if length(given_argtypes) < nargs && isVararg(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
Expand Down Expand Up @@ -101,7 +101,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
# For opaque closure, the closure environment is processed elsewhere
nargs -= 1
end
cache_argtypes = Vector{Any}(undef, nargs)
cache_argtypes = Vector{LatticeElement}(undef, nargs)
# First, if we're dealing with a varargs method, then we set the last element of `args`
# to the appropriate `Tuple` type or `PartialStruct` instance.
if !toplevel && isva
Expand All @@ -110,23 +110,24 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
linfo_argtypes = Any[Any for i = 1:nargs]
linfo_argtypes[end] = Vararg{Any}
end
vargtype = Tuple
vargtype = NativeType(Tuple)
else
linfo_argtypes_length = length(linfo_argtypes)
if nargs > linfo_argtypes_length
va = linfo_argtypes[linfo_argtypes_length]
if isvarargtype(va)
new_va = rewrap_unionall(unconstrain_vararg_length(va), specTypes)
vargtype = Tuple{new_va}
vargtype = NativeType(Tuple{new_va})
else
vargtype = Tuple{}
vargtype = NativeType(Tuple{})
end
else
vargtype_elements = Any[]
vargtype_elements = LatticeElement[]
for i in nargs:linfo_argtypes_length
p = linfo_argtypes[i]
p = unwraptv(isvarargtype(p) ? unconstrain_vararg_length(p) : p)
push!(vargtype_elements, elim_free_typevars(rewrap_unionall(p, specTypes)))
p = elim_free_typevars(rewrap_unionall(p, specTypes))
push!(vargtype_elements, NativeType(p))
end
for i in 1:length(vargtype_elements)
atyp = vargtype_elements[i]
Expand Down Expand Up @@ -165,7 +166,7 @@ function most_general_argtypes(method::Union{Method, Nothing}, @nospecialize(spe
elseif isconstType(atyp)
atyp = Const(atyp.parameters[1])
else
atyp = elim_free_typevars(rewrap_unionall(atyp, specTypes))
atyp = NativeType(elim_free_typevars(rewrap_unionall(atyp, specTypes)))
end
i == n && (lastatype = atyp)
cache_argtypes[i] = atyp
Expand Down Expand Up @@ -199,7 +200,7 @@ function matching_cache_argtypes(linfo::MethodInstance, ::Nothing, va_override::
return cache_argtypes, falses(length(cache_argtypes))
end

function cache_lookup(linfo::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
function cache_lookup(linfo::MethodInstance, given_argtypes::Argtypes, cache::Vector{InferenceResult})
method = linfo.def::Method
nargs::Int = method.nargs
method.isva && (nargs -= 1)
Expand Down
46 changes: 15 additions & 31 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,12 @@

const LineNum = Int

# The type of a variable load is either a value or an UndefVarError
# (only used in abstractinterpret, doesn't appear in optimize)
struct VarState
typ
undef::Bool
VarState(@nospecialize(typ), undef::Bool) = new(typ, undef)
end

"""
const VarTable = Vector{VarState}

The extended lattice that maps local variables to inferred type represented as `AbstractLattice`.
Each index corresponds to the `id` of `SlotNumber` which identifies each local variable.
Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statement
to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

mutable struct InferenceState
params::InferenceParams
result::InferenceResult # remember where to put the result
linfo::MethodInstance
sptypes::Vector{Any} # types of static parameter
slottypes::Vector{Any}
sptypes::Argtypes # types of static parameter
slottypes::Argtypes
mod::Module
currpc::LineNum
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand All @@ -40,7 +22,7 @@ mutable struct InferenceState
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}
# return type
bestguess #::Type
bestguess::LatticeElement
# current active instruction pointers
ip::BitSet
pc´´::LineNum
Expand Down Expand Up @@ -79,6 +61,8 @@ mutable struct InferenceState
sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
# NOTE we can't initialize `src.ssavaluetypes` as `Argtypes` to avoid
# an allocation within `ir_to_codeinf!(src)` where we widen all ssavaluetypes to native Julia types
src.ssavaluetypes = Any[ NOT_FOUND for i = 1:nssavalues ]
stmt_info = Any[ nothing for i = 1:length(code) ]

Expand All @@ -91,9 +75,9 @@ mutable struct InferenceState
argtypes = result.argtypes
nargs = length(argtypes)
s_argtypes = VarTable(undef, nslots)
slottypes = Vector{Any}(undef, nslots)
slottypes = Vector{LatticeElement}(undef, nslots)
for i in 1:nslots
at = (i > nargs) ? Bottom : argtypes[i]
at = (i > nargs) ? : LatticeElement(argtypes[i])
s_argtypes[i] = VarState(at, i > nargs)
slottypes[i] = at
end
Expand All @@ -120,7 +104,7 @@ mutable struct InferenceState
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, ip, 1, n, handler_at,
, ip, 1, n, handler_at,
ssavalue_uses,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
Expand Down Expand Up @@ -316,9 +300,9 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
else
ty = Const(v)
end
sp[i] = ty
sp[i] = LatticeElement(ty)
end
return sp
return collect(LatticeElement, sp)
end

_topmod(sv::InferenceState) = _topmod(sv.mod)
Expand All @@ -332,14 +316,14 @@ end

update_valid_age!(edge::InferenceState, sv::InferenceState) = update_valid_age!(sv, edge.valid_worlds)

function record_ssa_assign(ssa_id::Int, @nospecialize(new), frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
old = ssavaluetypes[ssa_id]
function record_ssa_assign(ssa_id::Int, new::LatticeElement, frame::InferenceState)
ssavaluetypes = frame.src.ssavaluetypes::SSAValueTypes
old = ssavaluetypes[ssa_id]::SSAValueType
if old === NOT_FOUND || !(new ⊑ old)
# typically, we expect that old ⊑ new (that output information only
# gets less precise with worse input information), but to actually
# guarantee convergence we need to use tmerge here to ensure that is true
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : tmerge(old, new)
# guarantee convergence we need to use here to ensure that is true
ssavaluetypes[ssa_id] = old === NOT_FOUND ? new : oldnew
W = frame.ip
s = frame.stmt_types
for r in frame.ssavalue_uses[ssa_id]
Expand Down
Loading