Skip to content

Commit

Permalink
improve many type stabilities in Core.Compiler.typeinf (#39549)
Browse files Browse the repository at this point in the history
All of these type instabilities were detected by JET.jl's self-profiling.
The following code will print type-instabilities/type-errors for all
code paths reachable from `typeinf(::NativeInterpreter, ::InferenceState)`.
```julia
julia> using JET
julia> report_call(Core.Compiler.typeinf,
(Core.Compiler.NativeInterpreter, Core.Compiler.InferenceState);
annotate_types = true)
```

The remaining error reports (e.g. `variable Core.Compiler.string is not
defined`) are caused by functionality that is missing on error paths.
  • Loading branch information
aviatesk authored Feb 16, 2021
1 parent b1fbe7f commit 1bc7f43
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 49 deletions.
39 changes: 22 additions & 17 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
# Under direct self-recursion, permit much greater use of reducers.
# here we assume that complexity(specTypes) :>= complexity(sig)
comparison = sv.linfo.specTypes
l_comparison = length(unwrap_unionall(comparison).parameters)
l_comparison = length(unwrap_unionall(comparison).parameters)::Int
spec_len = max(spec_len, l_comparison)
else
comparison = method.sig
Expand Down Expand Up @@ -700,16 +700,20 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
res = Union{}
nargs = length(aargtypes)
splitunions = 1 < unionsplitcost(aargtypes) <= InferenceParams(interp).MAX_APPLY_UNION_ENUM
ctypes = Any[Any[aft]]
ctypes = [Any[aft]]
infos = [Union{Nothing, AbstractIterationInfo}[]]
for i = 1:nargs
ctypes´ = []
infos′ = []
ctypes´ = Vector{Any}[]
infos′ = Vector{Union{Nothing, AbstractIterationInfo}}[]
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
if !isvarargtype(ti)
cti, info = precise_container_type(interp, itft, ti, sv)
cti_info = precise_container_type(interp, itft, ti, sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::Union{Nothing,AbstractIterationInfo}
else
cti, info = precise_container_type(interp, itft, unwrapva(ti), sv)
cti_info = precise_container_type(interp, itft, unwrapva(ti), sv)
cti = cti_info[1]::Vector{Any}
info = cti_info[2]::Union{Nothing,AbstractIterationInfo}
# We can't represent a repeating sequence of the same types,
# so tmerge everything together to get one type that represents
# everything.
Expand All @@ -726,7 +730,7 @@ function abstract_apply(interp::AbstractInterpreter, @nospecialize(itft), @nospe
continue
end
for j = 1:length(ctypes)
ct = ctypes[j]
ct = ctypes[j]::Vector{Any}
if isvarargtype(ct[end])
# This is vararg, we're not gonna be able to do any inlining,
# drop the info
Expand Down Expand Up @@ -850,7 +854,8 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
(a3 = argtypes[3]; isa(a3, Const)) && (idx = a3.val; isa(idx, Int)) &&
(a2 = argtypes[2]; a2 ⊑ Tuple)
# TODO: why doesn't this use the getfield_tfunc?
cti, _ = precise_container_type(interp, iterate, a2, sv)
cti_info = precise_container_type(interp, iterate, a2, sv)
cti = cti_info[1]::Vector{Any}
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
end
Expand Down Expand Up @@ -1392,7 +1397,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
delete!(W, pc)
frame.currpc = pc
frame.cur_hand = frame.handler_at[pc]
frame.stmt_edges[pc] === nothing || empty!(frame.stmt_edges[pc])
edges = frame.stmt_edges[pc]
edges === nothing || empty!(edges)
stmt = frame.src.code[pc]
changes = s[pc]::VarTable
t = nothing
Expand All @@ -1405,7 +1411,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
elseif isa(stmt, GotoNode)
pc´ = (stmt::GotoNode).label
elseif isa(stmt, GotoIfNot)
condt = abstract_eval_value(interp, stmt.cond, s[pc], frame)
condt = abstract_eval_value(interp, stmt.cond, changes, frame)
if condt === Bottom
empty!(frame.pclimitations)
end
Expand Down Expand Up @@ -1438,7 +1444,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
end
newstate_else = stupdate!(s[l], changes_else)
if newstate_else !== false
if newstate_else !== nothing
# add else branch to active IP list
if l < frame.pc´´
frame.pc´´ = l
Expand All @@ -1449,7 +1455,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end
elseif isa(stmt, ReturnNode)
pc´ = n + 1
rt = widenconditional(abstract_eval_value(interp, stmt.val, s[pc], frame))
rt = widenconditional(abstract_eval_value(interp, stmt.val, changes, frame))
if !isa(rt, Const) && !isa(rt, Type) && !isa(rt, PartialStruct) && !isa(rt, PartialOpaque)
# only propagate information we know we can store
# and is valid inter-procedurally
Expand Down Expand Up @@ -1483,9 +1489,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.cur_hand = Pair{Any,Any}(l, frame.cur_hand)
# propagate type info to exception handler
old = s[l]
new = s[pc]::VarTable
newstate_catch = stupdate!(old, new)
if newstate_catch !== false
newstate_catch = stupdate!(old, changes)
if newstate_catch !== nothing
if l < frame.pc´´
frame.pc´´ = l
end
Expand Down Expand Up @@ -1556,12 +1561,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# (such as a terminator for a loop, if-else, or try block),
# consider whether we should jump to an older backedge first,
# to try to traverse the statements in approximate dominator order
if newstate !== false
if newstate !== nothing
s[pc´] = newstate
end
push!(W, pc´)
pc = frame.pc´´
elseif newstate !== false
elseif newstate !== nothing
s[pc´] = newstate
pc = pc´
elseif pc´ in W
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ function statement_cost(ex::Expr, line::Int, src::CodeInfo, sptypes::Vector{Any}
end
a = ex.args[2]
if a isa Expr
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, params, error_path))
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path))
end
return cost
elseif head === :copyast
Expand All @@ -392,7 +392,7 @@ function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::CodeInfo,
thiscost = 0
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
params.unoptimize_throw_blocks && line in throw_blocks)::Int
throw_blocks !== nothing && line in throw_blocks)::Int
elseif stmt isa GotoNode
# loops are generally always expensive
# but assume that forward jumps are already accounted for from
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
labelmap = coverage ? fill(0, length(code)) : changemap
prevloc = zero(eltype(ci.codelocs))
stmtinfo = sv.stmt_info
ssavaluetypes = ci.ssavaluetypes::Vector{Any}
while idx <= length(code)
codeloc = ci.codelocs[idx]
if coverage && codeloc != prevloc && codeloc != 0
# insert a side-effect instruction before the current instruction in the same basic block
insert!(code, idx, Expr(:code_coverage_effect))
insert!(ci.codelocs, idx, codeloc)
insert!(ci.ssavaluetypes, idx, Nothing)
insert!(ssavaluetypes, idx, Nothing)
insert!(stmtinfo, idx, nothing)
changemap[oldidx] += 1
if oldidx < length(labelmap)
Expand All @@ -58,12 +59,12 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, narg
idx += 1
prevloc = codeloc
end
if code[idx] isa Expr && ci.ssavaluetypes[idx] === Union{}
if code[idx] isa Expr && ssavaluetypes[idx] === Union{}
if !(idx < length(code) && isa(code[idx + 1], ReturnNode) && !isdefined((code[idx + 1]::ReturnNode), :val))
# insert unreachable in the same basic block after the current instruction (splitting it)
insert!(code, idx + 1, ReturnNode())
insert!(ci.codelocs, idx + 1, ci.codelocs[idx])
insert!(ci.ssavaluetypes, idx + 1, Union{})
insert!(ssavaluetypes, idx + 1, Union{})
insert!(stmtinfo, idx + 1, nothing)
if oldidx < length(changemap)
changemap[oldidx + 1] += 1
Expand Down
28 changes: 16 additions & 12 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
call = thisarginfo.each[i]
new_stmt = Expr(:call, argexprs[2], def, state...)
state1 = insert_node!(ir, idx, call.rt, new_stmt)
new_sig = with_atype(call_sig(ir, new_stmt))
new_sig = with_atype(call_sig(ir, new_stmt)::Signature)
if isa(call.info, MethodMatchInfo) || isa(call.info, UnionSplitInfo)
info = isa(call.info, MethodMatchInfo) ?
MethodMatchInfo[call.info] : call.info.matches
Expand Down Expand Up @@ -680,7 +680,7 @@ function resolve_todo(todo::InliningTodo, et::Union{EdgeTracker, Nothing}, cache
spec = todo.spec::DelayedInliningSpec
isconst, src = find_inferred(todo.mi, spec.atypes, caches, spec.stmttype)

if isconst
if isconst && et !== nothing
push!(et, todo.mi)
return ConstantCase(src)
end
Expand Down Expand Up @@ -988,9 +988,12 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, invoke_data::Invok
sig.atype, method.sig)::SimpleVector
methsp = methsp::SimpleVector
match = MethodMatch(metharg, methsp, method, true)
result = analyze_method!(match, sig.atypes, state.et, state.caches, state.params, calltype)
et = state.et
result = analyze_method!(match, sig.atypes, et, state.caches, state.params, calltype)
handle_single_case!(ir, stmt, idx, result, true, todo)
intersect!(state.et, WorldRange(invoke_data.min_valid, invoke_data.max_valid))
if et !== nothing
intersect!(et, WorldRange(invoke_data.min_valid, invoke_data.max_valid))
end
return nothing
end

Expand Down Expand Up @@ -1118,6 +1121,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
sig.atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp, only_method, true)
else
meth = meth::MethodLookupResult
@assert length(meth) == 1
match = meth[1]
end
Expand Down Expand Up @@ -1145,6 +1149,8 @@ end
function assemble_inline_todo!(ir::IRCode, state::InliningState)
# todo = (inline_idx, (isva, isinvoke, na), method, spvals, inline_linetable, inline_ir, lie)
todo = Pair{Int, Any}[]
et = state.et
method_table = state.method_table
for idx in 1:length(ir.stmts)
r = process_simple!(ir, todo, idx, state)
r === nothing && continue
Expand Down Expand Up @@ -1176,20 +1182,18 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
nu = unionsplitcost(sig.atypes)
if nu == 1 || nu > state.params.MAX_UNION_SPLITTING
if !isa(info, MethodMatchInfo)
if state.method_table === nothing
continue
end
info = recompute_method_matches(sig.atype, state.params, state.et, state.method_table)
method_table === nothing && continue
et === nothing && continue
info = recompute_method_matches(sig.atype, state.params, et, method_table)
end
infos = MethodMatchInfo[info]
else
if !isa(info, UnionSplitInfo)
if state.method_table === nothing
continue
end
method_table === nothing && continue
et === nothing && continue
infos = MethodMatchInfo[]
for union_sig in UnionSplitSignature(sig.atypes)
push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, state.et, state.method_table))
push!(infos, recompute_method_matches(argtypes_to_type(union_sig), state.params, et, method_table))
end
else
infos = info.matches
Expand Down
4 changes: 3 additions & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,15 @@ function compute_basic_blocks(stmts::Vector{Any})
return CFG(blocks, basic_block_index)
end

# this function assumes that an insert position exists
function first_insert_for_bb(code, cfg::CFG, block::Int)
for idx in cfg.blocks[block].stmts
stmt = code[idx]
if !isa(stmt, PhiNode)
return idx
end
end
error("any insert position isn't found")
end

# SSA-indexed nodes
Expand Down Expand Up @@ -893,7 +895,7 @@ function kill_edge!(compact::IncrementalCompact, active_bb::Int, from::Int, to::
# Check if the block is now dead
if length(preds) == 0
for succ in copy(compact.result_bbs[compact.bb_rename_succ[to]].succs)
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred))
kill_edge!(compact, active_bb, to, findfirst(x->x === succ, compact.bb_rename_pred)::Int)
end
if to < active_bb
# Kill all statements in the block
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
# Having undef_token appear on the RHS is possible if we're on a dead branch.
# Do something reasonable here, by marking the LHS as undef as well.
if val !== undef_token
incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ))
incoming_vals[id] = SSAValue(make_ssa!(ci, code, idx, id, typ)::Int)
else
code[idx] = nothing
incoming_vals[id] = undef_token
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ end
function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, use_idx::Int, print::Bool)
if isa(op, SSAValue)
if op.id > length(ir.stmts)
def_bb = block_for_inst(ir.cfg, ir.new_nodes[op.id - length(ir.stmts)].pos)
def_bb = block_for_inst(ir.cfg, ir.new_nodes.info[op.id - length(ir.stmts)].pos)
else
def_bb = block_for_inst(ir.cfg, op.id)
end
if (def_bb == use_bb)
if op.id > length(ir.stmts)
@assert ir.new_nodes[op.id - length(ir.stmts)].pos <= use_idx
@assert ir.new_nodes.info[op.id - length(ir.stmts)].pos <= use_idx
else
if op.id >= use_idx
@verify_error "Def ($(op.id)) does not dominate use ($(use_idx)) in same BB"
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,8 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
add_cycle_backedge!(child, parent, parent.currpc)
union_caller_cycle!(ancestor, child)
child = parent
parent = child.parent
child === ancestor && break
parent = child.parent::InferenceState
end
end

Expand Down
6 changes: 3 additions & 3 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
if !isa(changes.var, Slot)
return stupdate!(state, changes.state)
end
newstate = false
newstate = nothing
changeid = slot_id(changes.var::Slot)
for i = 1:length(state)
if i == changeid
Expand Down Expand Up @@ -346,7 +346,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
end

function stupdate!(state::VarTable, changes::VarTable)
newstate = false
newstate = nothing
for i = 1:length(state)
newtype = changes[i]
oldtype = state[i]
Expand All @@ -360,7 +360,7 @@ end

stupdate!(state::Nothing, changes::VarTable) = copy(changes)

stupdate!(state::Nothing, changes::Nothing) = false
stupdate!(state::Nothing, changes::Nothing) = nothing

function stupdate1!(state::VarTable, change::StateUpdate)
if !isa(change.var, Slot)
Expand Down
8 changes: 4 additions & 4 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,18 +204,18 @@ end
# unioncomplexity estimates the number of calls to `tmerge` to obtain the given type by
# counting the Union instances, taking also into account those hidden in a Tuple or UnionAll
function unioncomplexity(u::Union)
return unioncomplexity(u.a) + unioncomplexity(u.b) + 1
return unioncomplexity(u.a)::Int + unioncomplexity(u.b)::Int + 1
end
function unioncomplexity(t::DataType)
t.name === Tuple.name || isvarargtype(t) || return 0
c = 0
for ti in t.parameters
c = max(c, unioncomplexity(ti))
c = max(c, unioncomplexity(ti)::Int)
end
return c
end
unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body), unioncomplexity(u.var.ub))
unioncomplexity(t::Core.TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T) : 0
unioncomplexity(u::UnionAll) = max(unioncomplexity(u.body)::Int, unioncomplexity(u.var.ub)::Int)
unioncomplexity(t::Core.TypeofVararg) = isdefined(t, :T) ? unioncomplexity(t.T)::Int : 0
unioncomplexity(@nospecialize(x)) = 0

function improvable_via_constant_propagation(@nospecialize(t))
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/validation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Expr head => argument count bounds
const VALID_EXPR_HEADS = IdDict{Any,Any}(
const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange}(
:call => 1:typemax(Int),
:invoke => 2:typemax(Int),
:static_parameter => 1:1,
Expand Down
8 changes: 6 additions & 2 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,14 @@ end

function rewrap_unionall(t::Core.TypeofVararg, @nospecialize(u))
isdefined(t, :T) || return t
if !isa(u, UnionAll)
return t
end
T = rewrap_unionall(t.T, u)
if !isdefined(t, :N) || t.N === u.var
return Vararg{rewrap_unionall(t.T, u)}
return Vararg{T}
end
Vararg{rewrap_unionall(t.T, u), t.N}
return Vararg{T, t.N}
end

# replace TypeVars in all enclosing UnionAlls with fresh TypeVars
Expand Down

0 comments on commit 1bc7f43

Please sign in to comment.