Skip to content

Commit

Permalink
more type-stable type-inference (#41697)
Browse files Browse the repository at this point in the history
(this PR is the final output of my demo at [our workshop](https://github.com/aviatesk/juliacon2021-workshop-pkgdev))

This PR eliminated much of runtime dispatches within our type inference
routine, that are reported by the following JET analysis:
```julia
using JETTest

const CC = Core.Compiler

function function_filter(@nospecialize(ft))
    ft === typeof(CC.isprimitivetype) && return false
    ft === typeof(CC.ismutabletype) && return false
    ft === typeof(CC.isbitstype) && return false
    ft === typeof(CC.widenconst) && return false
    ft === typeof(CC.widenconditional) && return false
    ft === typeof(CC.widenwrappedconditional) && return false
    ft === typeof(CC.maybe_extract_const_bool) && return false
    ft === typeof(CC.ignorelimited) && return false
    return true
end

function frame_filter((; linfo) = sv)
    meth = linfo.def
    isa(meth, Method) || return true
    return occursin("compiler/", string(meth.file))
end

report_dispatch(CC.typeinf, (CC.NativeInterpreter, CC.InferenceState); function_filter, frame_filter)
```

> on master
```
═════ 137 possible errors found ═════
...
```
> on this PR
```
═════ 51 possible errors found ═════
...
```

And it seems like this PR makes JIT slightly faster:
> on master
```julia
~/julia/julia master
❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));'
  3.659865 seconds (7.19 M allocations: 497.982 MiB, 3.94% gc time, 0.39% compilation time)
  2.696410 seconds (3.62 M allocations: 202.905 MiB, 7.49% gc time, 56.39% compilation time)
```
> on this PR
```julia
~/julia/julia avi/jetdemo* 7s
❯ ./usr/bin/julia -e '@time using Plots; @time plot(rand(10,3));'
  3.396974 seconds (7.16 M allocations: 491.442 MiB, 4.80% gc time, 0.28% compilation time)
  2.591130 seconds (3.48 M allocations: 196.026 MiB, 7.29% gc time, 56.72% compilation time)
```
  • Loading branch information
aviatesk authored Jul 29, 2021
1 parent 66f9b55 commit 795935f
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 73 deletions.
109 changes: 63 additions & 46 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
push!(edges, edge)
end
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_rt !== rt && const_rt rt
rt = const_rt
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
rt = const_rt
end
end
push!(const_results, const_result)
if const_result !== nothing
Expand All @@ -107,9 +110,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# try constant propagation with argtypes for this match
# this is in preparation for inlining, or improving the return result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
const_this_rt, const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_this_rt !== this_rt && const_this_rt this_rt
this_rt = const_this_rt
const_result = abstract_call_method_with_const_args(interp, result, f, this_argtypes, match, sv, false)
if const_result !== nothing
const_this_rt, const_result = const_result
if const_this_rt !== this_rt && const_this_rt this_rt
this_rt = const_this_rt
end
end
push!(const_results, const_result)
if const_result !== nothing
Expand Down Expand Up @@ -523,33 +529,35 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, resul
@nospecialize(f), argtypes::Vector{Any}, match::MethodMatch,
sv::InferenceState, va_override::Bool)
mi = maybe_get_const_prop_profitable(interp, result, f, argtypes, match, sv)
mi === nothing && return Any, nothing
mi === nothing && return nothing
# try constant prop'
inf_cache = get_inference_cache(interp)
inf_result = cache_lookup(mi, argtypes, inf_cache)
if inf_result === nothing
# if there might be a cycle, check to make sure we don't end up
# calling ourselves here.
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
# propagate different constant elements if the recursion is finite over the lattice
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
any(infstate.result.overridden_by_const)
let result = result # prevent capturing
if result.edgecycle && _any(InfStackUnwind(sv)) do infstate
# if the type complexity limiting didn't decide to limit the call signature (`result.edgelimited = false`)
# we can relax the cycle detection by comparing `MethodInstance`s and allow inference to
# propagate different constant elements if the recursion is finite over the lattice
return (result.edgelimited ? match.method === infstate.linfo.def : mi === infstate.linfo) &&
any(infstate.result.overridden_by_const)
end
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
return nothing
end
add_remark!(interp, sv, "[constprop] Edge cycle encountered")
return Any, nothing
end
inf_result = InferenceResult(mi, argtypes, va_override)
frame = InferenceState(inf_result, #=cache=#false, interp)
frame === nothing && return Any, nothing # this is probably a bad generated function (unsound), but just ignore it
frame === nothing && return nothing # this is probably a bad generated function (unsound), but just ignore it
frame.parent = sv
push!(inf_cache, inf_result)
typeinf(interp, frame) || return Any, nothing
typeinf(interp, frame) || return nothing
end
result = inf_result.result
# if constant inference hits a cycle, just bail out
isa(result, InferenceState) && return Any, nothing
isa(result, InferenceState) && return nothing
add_backedge!(mi, sv)
return result, inf_result
end
Expand Down Expand Up @@ -1174,7 +1182,8 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
nargtype === Bottom && return CallMeta(Bottom, false)
nargtype isa DataType || return CallMeta(Any, false) # other cases are not implemented below
isdispatchelem(ft) || return CallMeta(Any, false) # check that we might not have a subtype of `ft` at runtime, before doing supertype lookup below
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)
ft = ft::DataType
types = rewrap_unionall(Tuple{ft, unwrap_unionall(types).parameters...}, types)::Type
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
result = findsup(types, method_table(interp))
Expand All @@ -1196,12 +1205,14 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
if const_rt !== rt && const_rt rt
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
else
return CallMeta(rt, InvokeCallInfo(match, nothing))
const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
if const_result !== nothing
const_rt, const_result = const_result
if const_rt !== rt && const_rt rt
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
end
end
return CallMeta(rt, InvokeCallInfo(match, nothing))
end

# call where the function is known exactly
Expand Down Expand Up @@ -1301,19 +1312,20 @@ end
function abstract_call_opaque_closure(interp::AbstractInterpreter, closure::PartialOpaque, argtypes::Vector{Any}, sv::InferenceState)
pushfirst!(argtypes, closure.env)
sig = argtypes_to_type(argtypes)
(; rt, edge) = result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, sv)
(; rt, edge) = result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
edge !== nothing && add_backedge!(edge, sv)
tt = closure.typ
sigT = unwrap_unionall(tt).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source::Method, sig <: rewrap_unionall(sigT, tt))
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
info = OpaqueClosureCallInfo(match)
if !result.edgecycle
const_rettype, const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
const_result = abstract_call_method_with_const_args(interp, result, closure, argtypes,
match, sv, closure.isva)
if const_rettype rt
rt = const_rettype
end
if const_result !== nothing
const_rettype, const_result = const_result
if const_rettype rt
rt = const_rettype
end
info = ConstCallInfo(info, Union{Nothing,InferenceResult}[const_result])
end
end
Expand All @@ -1323,7 +1335,7 @@ end
function most_general_argtypes(closure::PartialOpaque)
ret = Any[]
cc = widenconst(closure)
argt = unwrap_unionall(cc).parameters[1]
argt = (unwrap_unionall(cc)::DataType).parameters[1]
if !isa(argt, DataType) || argt.name !== typename(Tuple)
argt = Tuple
end
Expand All @@ -1338,8 +1350,8 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
f = argtype_to_function(ft)
if isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
elseif (uft = unwrap_unionall(ft); isa(uft, DataType) && uft.name === typename(Core.OpaqueClosure))
return CallMeta(rewrap_unionall((uft::DataType).parameters[2], ft), false)
elseif f === nothing
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
Expand Down Expand Up @@ -1534,12 +1546,12 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
n = fieldcount(t)
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val) &&
let t = t; _all(i->getfield(at.val, i) isa fieldtype(t, i), 1:n); end
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields) &&
let t = t, at = at; _all(i->at.fields[i] fieldtype(t, i), 1:n); end
t = PartialStruct(t, at.fields)
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields::Vector{Any}) &&
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] fieldtype(t, i), 1:n); end
t = PartialStruct(t, at.fields::Vector{Any})
end
end
elseif e.head === :new_opaque_closure
Expand Down Expand Up @@ -1587,7 +1599,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
sym = e.args[1]
t = Bool
if isa(sym, SlotNumber)
vtyp = vtypes[slot_id(sym)]
vtyp = vtypes[slot_id(sym)]::VarState
if vtyp.typ === Bottom
t = Const(false) # never assigned previously
elseif !vtyp.undef
Expand All @@ -1602,7 +1614,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
t = Const(true)
end
elseif isa(sym, Expr) && sym.head === :static_parameter
n = sym.args[1]
n = sym.args[1]::Int
if 1 <= n <= length(sv.sptypes)
spty = sv.sptypes[n]
if isa(spty, Const)
Expand Down Expand Up @@ -1637,7 +1649,7 @@ function abstract_eval_global(M::Module, s::Symbol)
end

function abstract_eval_ssavalue(s::SSAValue, src::CodeInfo)
typ = src.ssavaluetypes[s.id]
typ = (src.ssavaluetypes::Vector{Any})[s.id]
if typ === NOT_FOUND
return Bottom
end
Expand Down Expand Up @@ -1725,6 +1737,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
isva = isa(def, Method) && def.isva
nslots = nargs - isva
slottypes = frame.slottypes
ssavaluetypes = frame.src.ssavaluetypes::Vector{Any}
while frame.pc´´ <= n
# make progress on the active ip set
local pc::Int = frame.pc´´ # current program-counter
Expand Down Expand Up @@ -1828,7 +1841,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
for (caller, caller_pc) in frame.cycle_backedges
# notify backedges of updated type information
typeassert(caller.stmt_types[caller_pc], VarTable) # we must have visited this statement before
if !(caller.src.ssavaluetypes[caller_pc] === Any)
if !((caller.src.ssavaluetypes::Vector{Any})[caller_pc] === Any)
# no reason to revisit if that call-site doesn't affect the final result
if caller_pc < caller.pc´´
caller.pc´´ = caller_pc
Expand All @@ -1838,6 +1851,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end
elseif hd === :enter
stmt = stmt::Expr
l = stmt.args[1]::Int
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
# propagate type info to exception handler
Expand All @@ -1853,21 +1867,24 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
typeassert(states[l], VarTable)
frame.handler_at[l] = frame.cur_hand
elseif hd === :leave
stmt = stmt::Expr
for i = 1:((stmt.args[1])::Int)
frame.cur_hand = (frame.cur_hand::Pair{Any,Any}).second
end
else
if hd === :(=)
stmt = stmt::Expr
t = abstract_eval_statement(interp, stmt.args[2], changes, frame)
if t === Bottom
break
end
frame.src.ssavaluetypes[pc] = t
ssavaluetypes[pc] = t
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(t, false), changes, false)
end
elseif hd === :method
stmt = stmt::Expr
fname = stmt.args[1]
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
Expand All @@ -1882,7 +1899,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if !isempty(frame.ssavalue_uses[pc])
record_ssa_assign(pc, t, frame)
else
frame.src.ssavaluetypes[pc] = t
ssavaluetypes[pc] = t
end
end
if frame.cur_hand !== nothing && isa(changes, StateUpdate)
Expand All @@ -1903,7 +1920,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

if t === nothing
# mark other reached expressions as `Any` to indicate they don't throw
frame.src.ssavaluetypes[pc] = Any
ssavaluetypes[pc] = Any
end

pc´ > n && break # can't proceed with the fast-path fall-through
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
while temp isa UnionAll
temp = temp.body
end
sigtypes = temp.parameters
sigtypes = (temp::DataType).parameters
for j = 1:length(sigtypes)
tj = sigtypes[j]
if isType(tj) && tj.parameters[1] === Pi
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
for metanode in ir.meta
push!(ci.code, metanode)
push!(ci.codelocs, 1)
push!(ci.ssavaluetypes, Any)
push!(ci.ssavaluetypes::Vector{Any}, Any)
push!(ci.ssaflags, 0x00)
end
# Translate BB Edges to statement edges
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ function type_lift_pass!(ir::IRCode)
if haskey(processed, id)
val = processed[id]
else
push!(worklist, (id, up_id, new_phi, i))
push!(worklist, (id, up_id, new_phi::SSAValue, i))
continue
end
else
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
changed = false
for new_idx in type_refine_phi
node = new_nodes.stmts[new_idx]
new_typ = recompute_type(node[:inst], ci, ir, ir.sptypes, slottypes)
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes)
if !(node[:type] new_typ) || !(new_typ node[:type])
node[:type] = new_typ
changed = true
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
if length(argtypes) - 1 == tf[2]
argtypes = argtypes[1:end-1]
else
vatype = argtypes[end]
vatype = argtypes[end]::Core.TypeofVararg
argtypes = argtypes[1:end-1]
while length(argtypes) < tf[1]
push!(argtypes, unwrapva(vatype))
Expand Down Expand Up @@ -1670,7 +1670,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
aft = argtypes[2]
if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) ||
(isconcretetype(aft) && !(aft <: Builtin))
af_argtype = isa(tt, Const) ? tt.val : tt.parameters[1]
af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1]
if isa(af_argtype, DataType) && af_argtype <: Tuple
argtypes_vec = Any[aft, af_argtype.parameters...]
if contains_is(argtypes_vec, Union{})
Expand Down
Loading

2 comments on commit 795935f

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here.

Please sign in to comment.