Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release: backport complicated compiler changes #42180

Merged
merged 3 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 176 additions & 128 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

36 changes: 19 additions & 17 deletions base/compiler/inferenceresult.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,35 @@ end
# for the provided `linfo` and `given_argtypes`. The purpose of this function is
# to return a valid value for `cache_lookup(linfo, argtypes, cache).argtypes`,
# so that we can construct cache-correct `InferenceResult`s in the first place.
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override)
function matching_cache_argtypes(linfo::MethodInstance, given_argtypes::Vector, va_override::Bool)
@assert isa(linfo.def, Method) # ensure the next line works
nargs::Int = linfo.def.nargs
@assert length(given_argtypes) >= (nargs - 1)
given_argtypes = anymap(widenconditional, given_argtypes)
if va_override || linfo.def.isva
isva = va_override || linfo.def.isva
if isva || isvarargtype(given_argtypes[end])
isva_given_argtypes = Vector{Any}(undef, nargs)
for i = 1:(nargs - 1)
for i = 1:(nargs - isva)
isva_given_argtypes[i] = argtype_by_index(given_argtypes, i)
end
if length(given_argtypes) >= nargs || !isvarargtype(given_argtypes[end])
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[nargs:end])
else
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[end:end])
if isva
if length(given_argtypes) < nargs && isvarargtype(given_argtypes[end])
last = length(given_argtypes)
else
last = nargs
end
isva_given_argtypes[nargs] = tuple_tfunc(given_argtypes[last:end])
end
given_argtypes = isva_given_argtypes
end
@assert length(given_argtypes) == nargs
cache_argtypes, overridden_by_const = matching_cache_argtypes(linfo, nothing, va_override)
if nargs === length(given_argtypes)
for i in 1:nargs
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
# prefer the argtype we were given over the one computed from `linfo`
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
end
for i in 1:nargs
given_argtype = given_argtypes[i]
cache_argtype = cache_argtypes[i]
if !is_argtype_match(given_argtype, cache_argtype, overridden_by_const[i])
# prefer the argtype we were given over the one computed from `linfo`
cache_argtypes[i] = given_argtype
overridden_by_const[i] = true
end
end
return cache_argtypes, overridden_by_const
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ function sptypes_from_meth_instance(linfo::MethodInstance)
while temp isa UnionAll
temp = temp.body
end
sigtypes = temp.parameters
sigtypes = (temp::DataType).parameters
for j = 1:length(sigtypes)
tj = sigtypes[j]
if isType(tj) && tj.parameters[1] === Pi
Expand Down
25 changes: 13 additions & 12 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,11 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return true
end

# Convert IRCode back to CodeInfo and compute inlining cost and sideeffects
# compute inlining cost and sideeffects
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
def = opt.linfo.def
nargs = Int(opt.nargs) - 1
(; src, nargs, linfo) = opt
(; def, specTypes) = linfo
nargs = Int(nargs) - 1

force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)

Expand All @@ -221,7 +222,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
end
end
if proven_pure
for fl in opt.src.slotflags
for fl in src.slotflags
if (fl & SLOT_USEDUNDEF) != 0
proven_pure = false
break
Expand All @@ -230,7 +231,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
end
end
if proven_pure
opt.src.pure = true
src.pure = true
end

if proven_pure
Expand All @@ -243,7 +244,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
if !(isa(result, Const) && !is_inlineable_constant(result.val))
opt.const_api = true
end
force_noinline || (opt.src.inlineable = true)
force_noinline || (src.inlineable = true)
end
end

Expand All @@ -252,7 +253,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
# determine and cache inlineability
union_penalties = false
if !force_noinline
sig = unwrap_unionall(opt.linfo.specTypes)
sig = unwrap_unionall(specTypes)
if isa(sig, DataType) && sig.name === Tuple.name
for P in sig.parameters
P = unwrap_unionall(P)
Expand All @@ -264,25 +265,25 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
else
force_noinline = true
end
if !opt.src.inlineable && result === Union{}
if !src.inlineable && result === Union{}
force_noinline = true
end
end
if force_noinline
opt.src.inlineable = false
src.inlineable = false
elseif isa(def, Method)
if opt.src.inlineable && isdispatchtuple(opt.linfo.specTypes)
if src.inlineable && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
else
bonus = 0
if result ⊑ Tuple && !isconcretetype(widenconst(result))
bonus = params.inline_tupleret_bonus
end
if opt.src.inlineable
if src.inlineable
# For functions declared @inline, increase the cost threshold 20x
bonus += params.inline_cost_threshold*19
end
opt.src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
src.inlineable = isinlineable(def, opt, params, union_penalties, bonus)
end
end

Expand Down
32 changes: 17 additions & 15 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,10 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
push!(linetable, LineInfoNode(entry.module, entry.method, entry.file, entry.line,
(entry.inlined_at > 0 ? entry.inlined_at + linetable_offset : inlined_at)))
end
nargs_def = item.mi.def.nargs
isva = nargs_def > 0 && item.mi.def.isva
(; def, sparam_vals) = item.mi
nargs_def = def.nargs::Int32
isva = nargs_def > 0 && def.isva
sig = def.sig
if isva
vararg = mk_tuplecall!(compact, argexprs[nargs_def:end], compact.result[idx][:line])
argexprs = Any[argexprs[1:(nargs_def - 1)]..., vararg]
Expand Down Expand Up @@ -347,7 +349,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
# face of rename_arguments! mutating in place - should figure out
# something better eventually.
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
if isa(stmt′, ReturnNode)
isa(stmt′.val, SSAValue) && (compact.used_ssas[stmt′.val.id] += 1)
return_value = SSAValue(idx′)
Expand All @@ -374,7 +376,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
inline_compact = IncrementalCompact(compact, spec.ir, compact.result_idx)
for ((_, idx′), stmt′) in inline_compact
inline_compact[idx′] = nothing
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, item.mi.def.sig, item.mi.sparam_vals, linetable_offset, boundscheck_idx, compact)
stmt′ = ssa_substitute!(idx′, stmt′, argexprs, sig, sparam_vals, linetable_offset, boundscheck_idx, compact)
if isa(stmt′, ReturnNode)
if isdefined(stmt′, :val)
val = stmt′.val
Expand Down Expand Up @@ -709,9 +711,8 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, match::Meth
return mi
end

function compileable_specialization(et::Union{EdgeTracker, Nothing}, result::InferenceResult)
mi = specialize_method(result.linfo.def::Method, result.linfo.specTypes,
result.linfo.sparam_vals, false, true)
function compileable_specialization(et::Union{EdgeTracker, Nothing}, (; linfo)::InferenceResult)
mi = specialize_method(linfo.def::Method, linfo.specTypes, linfo.sparam_vals, false, true)
mi !== nothing && et !== nothing && push!(et, mi::MethodInstance)
return mi
end
Expand Down Expand Up @@ -1065,9 +1066,9 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
pushfirst!(atypes, atype0)

if isa(result, InferenceResult)
item = InliningTodo(result, atypes, calltype)
validate_sparams(item.mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: item.mi.def.sig
(; mi) = item = InliningTodo(result, atypes, calltype)
validate_sparams(mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state))
handle_single_case!(ir, stmt, idx, item, true, todo)
return nothing
Expand Down Expand Up @@ -1195,7 +1196,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
for i in 1:length(infos)
info = infos[i]
meth = info.results
if meth === missing || meth.ambig
if meth.ambig
# Too many applicable methods
# Or there is a (partial?) ambiguity
too_many = true
Expand All @@ -1213,19 +1214,20 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
only_method = false
end
for match in meth
signature_union = Union{signature_union, match.spec_types}
if !isdispatchtuple(match.spec_types)
spec_types = match.spec_types
signature_union = Union{signature_union, spec_types}
if !isdispatchtuple(spec_types)
fully_covered = false
continue
end
case = analyze_method!(match, sig.atypes, state, calltype)
if case === nothing
fully_covered = false
continue
elseif _any(p->p[1] === match.spec_types, cases)
elseif _any(p->p[1] === spec_types, cases)
continue
end
push!(cases, Pair{Any,Any}(match.spec_types, case))
push!(cases, Pair{Any,Any}(spec_types, case))
end
end

Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function replace_code_newstyle!(ci::CodeInfo, ir::IRCode, nargs::Int)
for metanode in ir.meta
push!(ci.code, metanode)
push!(ci.codelocs, 1)
push!(ci.ssavaluetypes, Any)
push!(ci.ssavaluetypes::Vector{Any}, Any)
push!(ci.ssaflags, 0x00)
end
# Translate BB Edges to statement edges
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ function type_lift_pass!(ir::IRCode)
if haskey(processed, id)
val = processed[id]
else
push!(worklist, (id, up_id, new_phi, i))
push!(worklist, (id, up_id, new_phi::SSAValue, i))
continue
end
else
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
changed = false
for new_idx in type_refine_phi
node = new_nodes.stmts[new_idx]
new_typ = recompute_type(node[:inst], ci, ir, ir.sptypes, slottypes)
new_typ = recompute_type(node[:inst]::Union{PhiNode,PhiCNode}, ci, ir, ir.sptypes, slottypes)
if !(node[:type] ⊑ new_typ) || !(new_typ ⊑ node[:type])
node[:type] = new_typ
changed = true
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ to re-consult the method table. This info is illegal on any statement that is
not a call to a generic function.
"""
struct MethodMatchInfo
results::Union{Missing, MethodLookupResult}
results::MethodLookupResult
end

"""
Expand Down
4 changes: 2 additions & 2 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ function builtin_tfunction(interp::AbstractInterpreter, @nospecialize(f), argtyp
if length(argtypes) - 1 == tf[2]
argtypes = argtypes[1:end-1]
else
vatype = argtypes[end]
vatype = argtypes[end]::Core.TypeofVararg
argtypes = argtypes[1:end-1]
while length(argtypes) < tf[1]
push!(argtypes, unwrapva(vatype))
Expand Down Expand Up @@ -1733,7 +1733,7 @@ function return_type_tfunc(interp::AbstractInterpreter, argtypes::Vector{Any}, s
aft = argtypes[2]
if isa(aft, Const) || (isType(aft) && !has_free_typevars(aft)) ||
(isconcretetype(aft) && !(aft <: Builtin))
af_argtype = isa(tt, Const) ? tt.val : tt.parameters[1]
af_argtype = isa(tt, Const) ? tt.val : (tt::DataType).parameters[1]
if isa(af_argtype, DataType) && af_argtype <: Tuple
argtypes_vec = Any[aft, af_argtype.parameters...]
if contains_is(argtypes_vec, Union{})
Expand Down
Loading