Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve optimization of harmless call cycles #38231

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ const _REF_NAME = Ref.body.name
call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])


function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.params.unoptimize_throw_blocks && sv.currpc in sv.throw_blocks
Expand Down Expand Up @@ -1380,7 +1379,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if isa(fname, Slot)
changes = StateUpdate(fname, VarState(Any, false), changes)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
elseif hd === :meta
if stmt.args[1] == :noinline
frame.saw_noinline = true
end
elseif hd === :inbounds || hd === :loopinfo || hd === :code_coverage_effect
# these do not generate code
else
t = abstract_eval_statement(interp, stmt, changes, frame)
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ mutable struct InferenceState
limited::Bool
inferred::Bool
dont_work_on_me::Bool
saw_noinline::Bool

# The place to look up methods while working on this function.
# In particular, we cache method lookup results for the same function to
Expand Down Expand Up @@ -113,7 +114,7 @@ mutable struct InferenceState
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
cached, false, false, false,
cached, false, false, false, false,
CachedMethodTable(method_table(interp)),
interp)
result.result = frame
Expand Down
47 changes: 28 additions & 19 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,24 @@ function push!(et::EdgeTracker, ci::CodeInstance)
push!(et, ci.def)
end

struct InferenceCaches{T, S}
inf_cache::T
# An mi_cache that overlays some base cache, but also caches
# temporary results while we're working on a cycle
struct CycleInferenceCache{S}
cycle_mis::IdDict{MethodInstance, InferenceResult}
mi_cache::S
end
CycleInferenceCache(mi_cache) = CycleInferenceCache(IdDict{MethodInstance, InferenceResult}(), mi_cache)

function setindex!(cic::CycleInferenceCache, v::InferenceResult, mi::MethodInstance)
cic.cycle_mis[mi] = v
return cic
end

function get(cic::CycleInferenceCache, mi::MethodInstance, @nospecialize(default))
result = get(cic.cycle_mis, mi, nothing)
result !== nothing && return result
return get(cic.mi_cache, mi, default)
end

struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCaches, Nothing}, V <: Union{Nothing, MethodTableView}}
params::OptimizationParams
Expand All @@ -42,24 +56,22 @@ mutable struct OptimizationState
sptypes::Vector{Any} # static parameters
slottypes::Vector{Any}
const_api::Bool
inlining::InliningState
params::OptimizationParams
et::Union{Nothing, EdgeTracker}
mt::Union{Nothing, MethodTableView}
function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
s_edges = frame.stmt_edges[1]
if s_edges === nothing
s_edges = []
frame.stmt_edges[1] = s_edges
end
src = frame.src
inlining = InliningState(params,
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
InferenceCaches(
get_inference_cache(interp),
WorldView(code_cache(interp), frame.world)),
method_table(interp))
return new(frame.linfo,
src, frame.stmt_info, frame.mod, frame.nargs,
frame.sptypes, frame.slottypes, false,
inlining)
params,
EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
method_table(interp))
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
# prepare src for running optimization passes
Expand All @@ -86,16 +98,13 @@ mutable struct OptimizationState
end
# Allow using the global MI cache, but don't track edges.
# This method is mostly used for unit testing the optimizer
inlining = InliningState(params,
nothing,
InferenceCaches(
get_inference_cache(interp),
WorldView(code_cache(interp), get_world_counter())),
method_table(interp))

return new(linfo,
src, stmt_info, inmodule, nargs,
sptypes_from_meth_instance(linfo), slottypes, false,
inlining)
params,
nothing,
method_table(interp))
end
end

Expand Down Expand Up @@ -180,10 +189,10 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
end

# run the optimization work
function optimize(opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
function optimize(opt::OptimizationState, params::OptimizationParams, caches::InferenceCaches, @nospecialize(result))
def = opt.linfo.def
nargs = Int(opt.nargs) - 1
@timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
@timeit "optimizer" ir = run_passes(opt.src, nargs, caches, opt)
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)

# compute inlining and other related optimizations
Expand Down
13 changes: 9 additions & 4 deletions base/compiler/ssair/domtree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,14 @@ end

length(D::DFSTree) = length(D.from_pre)

function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
succs(bb::BasicBlock) = bb.succs

function DFS!(D::DFSTree, blocks::Vector; roots = 1)
copy!(D, DFSTree(length(blocks)))
to_visit = Tuple{BBNumber, PreNumber, Bool}[(1, 0, false)]
to_visit = Tuple{BBNumber, PreNumber, Bool}[]
for root in roots
push!(to_visit, (root, 0, false))
end
pre_num = 1
post_num = 1
while !isempty(to_visit)
Expand Down Expand Up @@ -144,7 +149,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
to_visit[end] = (current_node_bb, parent_pre, true)

# Push children to the stack
for succ_bb in blocks[current_node_bb].succs
for succ_bb in succs(blocks[current_node_bb])
push!(to_visit, (succ_bb, pre_num, false))
end

Expand All @@ -161,7 +166,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
return D
end

DFS(blocks::Vector{BasicBlock}) = DFS!(DFSTree(0), blocks)
DFS(blocks::Vector; roots = 1) = DFS!(DFSTree(0), blocks; roots)

"""
Keeps the per-BB state of the Semi NCA algorithm. In the original formulation,
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,15 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState)
return ir
end

function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
function run_passes(ci::CodeInfo, nargs::Int, caches::InferenceCaches, sv::OptimizationState)
preserve_coverage = coverage_enabled(sv.mod)
ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv)
ir = slot2reg(ir, ci, nargs, sv)
#@Base.show ("after_construct", ir)
# TODO: Domsorting can produce an updated domtree - no need to recompute here
@timeit "compact 1" ir = compact!(ir)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
inlining = InliningState(sv.params, sv.et, caches, sv.mt)
@timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, inlining, ci.propagate_inbounds)
#@timeit "verify 2" verify_ir(ir)
ir = compact!(ir)
#@Base.show ("before_sroa", ir)
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1419,8 +1419,18 @@ function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::Inferenc
return svec(true, quoted(linfo.rettype_const))
end
return svec(false, linfo.inferred)
else
# `linfo` may be `nothing` or an IRCode here
return svec(false, linfo)
elseif isa(linfo, InferenceResult)
let inferred_src = linfo.src
if isa(inferred_src, CodeInfo)
return svec(false, inferred_src)
end
if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val)
return svec(true, quoted(inferred_src.val),)
end
end
linfo = nothing
end

# `linfo` may be `nothing` or an IRCode here
return svec(false, linfo)
end
43 changes: 40 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,36 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
end

struct ForwardEdges
edges::Vector{Int}
end
ForwardEdges() = ForwardEdges(Int[])
succs(f::ForwardEdges) = f.edges
push!(f::ForwardEdges, i::Int) = push!(f.edges, i)

function postorder_sort_frames(frames)
length(frames) == 1 && return frames
roots = Int[i for i in 1:length(frames) if frames[i].saw_noinline]
# If there are no noinline annoations, just leave the default order
isempty(roots) && return frames

# Number frames
numbering = IdDict{Any, Int}(frames[i] => i for i in 1:length(frames))

# Compute forward edges
forward_edges = ForwardEdges[ForwardEdges() for i in 1:length(frames)]
for i in 1:length(frames)
frame = frames[i]
for (edge, _) in frame.cycle_backedges
push!(forward_edges[numbering[edge]], i)
end
end

# Compute postorder
dfs_tree = DFS(forward_edges; roots=roots)
return InferenceState[frames[i] for i in dfs_tree.from_post]
end

function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
# with no active ip's, frame is done
Expand All @@ -220,19 +250,26 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
for caller in frames
finish(caller, interp)
end
# We postorder sort frames rooted on any frames marked noinline (if any).
# This makes sure that the inliner has the maximum opportunity to inline.
frames = postorder_sort_frames(frames)
# collect results for the new expanded frame
results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
# empty!(frames)
valid_worlds = frame.valid_worlds
cached = frame.cached
caches = InferenceCaches(interp)
cycle_cache = CycleInferenceCache(caches.mi_cache)
caches = InferenceCaches(caches.inf_cache, cycle_cache)
if cached || frame.parent !== nothing
for (caller, doopt) in results
opt = caller.src
if opt isa OptimizationState
run_optimizer = doopt && may_optimize(interp)
if run_optimizer
optimize(opt, OptimizationParams(interp), caller.result)
cycle_cache[opt.linfo] = caller
optimize(opt, OptimizationParams(interp), caches, caller.result)
finish(opt.src, interp)
# finish updating the result struct
validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
Expand All @@ -251,7 +288,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
# As a hack the et reuses frame_edges[1] to push any optimization
# edges into, so we don't need to handle them specially here
valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[])
valid_worlds = intersect(valid_worlds, opt.et.valid_worlds[])
end
end
end
Expand Down Expand Up @@ -768,7 +805,7 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
if typeinf(interp, frame) && run_optimizer
opt_params = OptimizationParams(interp)
opt = OptimizationState(frame, opt_params, interp)
optimize(opt, opt_params, result.result)
optimize(opt, opt_params, InferenceCaches(interp), result.result)
opt.src.inferred = true
end
ccall(:jl_typeinf_end, Cvoid, ())
Expand Down
10 changes: 10 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ struct InferenceParams
end
end

struct InferenceCaches{T, S}
inf_cache::T
mi_cache::S
end

"""
NativeInterpreter

Expand Down Expand Up @@ -187,6 +192,11 @@ get_inference_cache(ni::NativeInterpreter) = ni.cache

code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)

InferenceCaches(ni::NativeInterpreter) =
InferenceCaches(
get_inference_cache(ni),
WorldView(code_cache(ni), ni.world))

"""
lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)

Expand Down