From 53401f1cb4e829b3a0d74295473c9a48fff2bbab Mon Sep 17 00:00:00 2001
From: Keno Fischer
Date: Thu, 29 Oct 2020 17:36:53 -0400
Subject: [PATCH] Improve optimization of harmless call cycles

When inference detects a call cycle, one of two things can happen:

1. It determines that, in order for inference to converge, it needs to
   limit the signatures of some methods to something more general, or
2. The cycle is determined to be harmless at the inference level
   (because even though there is a cycle in the CFG, there is no
   dependency cycle of type information).

In the first case, we simply disable optimizations, which is sensible,
because we are likely to have to recompute some information anyway when
we actually get there dynamically.

In the second case, however, we do run optimizations, but it is a bit of
an unusual case. In particular, inference usually delivers methods to
inlining in postorder (meaning callees get optimized before their
callers), such that a caller can always inline a callee. However, if
there is a cycle, there is of course no unique postorder of functions,
since by definition we are looking at a strongly connected component. In
this case, we essentially just pick an arbitrary order (which moreover
depends on the point at which we enter the cycle and is subsequently
cached, leading to potential performance instabilities depending on
function order).

However, the arbitrary order is quite possibly suboptimal. For example,
in #36414 we have the cycle `randn` -> `_randn` -> `randn_unlikely` ->
`rand`. In this cycle, the author of the code expected both `randn` and
`_randn` to inline and annotated the functions as such. However, on 1.5+
the order we happened to pick would have inlined `randn_unlikely` into
`_randn` (had it not been marked noinline), with a hard call break
between `randn` and `_randn`, which is problematic from a performance
standpoint.

This PR addresses this by introducing a heuristic: if some functions in
a cycle are marked `@noinline`, we want to make sure to optimize these
last (since they will never be inlined anyway). To ensure this, and to
restore a postorder whenever the `@noinline` functions happen to break
the cycle, we perform a DFS traversal rooted at the `@noinline`
functions and then optimize the functions in the cycle in DFS postorder.
This may still not be a true postorder of the inlining graph (if the
`@noinline` functions do not break the cycle), but even in that case it
should be no worse than the default order.
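For illustration, here is a rough standalone sketch of the ordering
heuristic. It is deliberately simplified: the `cycle_order` helper, the
plain `Dict`-based call graph, and the symbolic frame names below are
made up for this example and are not the compiler's actual data
structures. The idea is a DFS postorder over the cycle's caller ->
callee edges, rooted at the `@noinline` frames:

    # Simplified model: nodes are the frames in the cycle, `edges` maps each
    # caller to its callees, and `noinline` marks the frames we want to come last.
    function cycle_order(names::Vector{Symbol}, edges::Dict{Symbol,Vector{Symbol}},
                         noinline::Set{Symbol})
        roots = [n for n in names if n in noinline]
        isempty(roots) && return names            # no @noinline: keep the default order
        order = Symbol[]
        visited = Set{Symbol}()
        function dfs(n)
            n in visited && return
            push!(visited, n)
            foreach(dfs, get(edges, n, Symbol[])) # visit callees first...
            push!(order, n)                       # ...then emit the caller (postorder)
        end
        foreach(dfs, roots)
        append!(order, (n for n in names if !(n in visited)))  # frames not reached from the roots
        return order
    end

    # Toy cycle in the spirit of the example above, with randn_unlikely marked @noinline:
    edges = Dict(:randn => [:_randn],
                 :_randn => [:randn_unlikely],
                 :randn_unlikely => [:randn])
    cycle_order([:randn, :_randn, :randn_unlikely], edges, Set([:randn_unlikely]))
    # -> [:_randn, :randn, :randn_unlikely]

With this order, `_randn` is optimized before `randn` (so `randn` can
still inline it), and the `@noinline` frame is simply optimized last; if
no frame is marked `@noinline`, the existing order is kept.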
Fixes #36414
Closes #37234
---
 base/compiler/abstractinterpretation.jl |  7 ++--
 base/compiler/inferencestate.jl         |  3 +-
 base/compiler/optimize.jl               | 47 +++++++++++++++----------
 base/compiler/ssair/domtree.jl          | 13 ++++---
 base/compiler/ssair/driver.jl           |  5 +--
 base/compiler/ssair/inlining.jl         | 16 +++++++--
 base/compiler/typeinfer.jl              | 43 ++++++++++++++++++++--
 base/compiler/types.jl                  | 10 ++++++
 8 files changed, 110 insertions(+), 34 deletions(-)

diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl
index ae1e269af597d..b1655686acc89 100644
--- a/base/compiler/abstractinterpretation.jl
+++ b/base/compiler/abstractinterpretation.jl
@@ -16,7 +16,6 @@ const _REF_NAME = Ref.body.name
 call_result_unused(frame::InferenceState, pc::LineNum=frame.currpc) =
     isexpr(frame.src.code[frame.currpc], :call) && isempty(frame.ssavalue_uses[pc])

-
 function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), argtypes::Vector{Any}, @nospecialize(atype), sv::InferenceState,
                                   max_methods::Int = InferenceParams(interp).MAX_METHODS)
     if sv.params.unoptimize_throw_blocks && sv.currpc in sv.throw_blocks
@@ -1380,7 +1379,11 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
                 if isa(fname, Slot)
                     changes = StateUpdate(fname, VarState(Any, false), changes)
                 end
-            elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
+            elseif hd === :meta
+                if stmt.args[1] == :noinline
+                    frame.saw_noinline = true
+                end
+            elseif hd === :inbounds || hd === :loopinfo || hd === :code_coverage_effect
                 # these do not generate code
             else
                 t = abstract_eval_statement(interp, stmt, changes, frame)
diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl
index 2d5fce04c0454..12671e89f855a 100644
--- a/base/compiler/inferencestate.jl
+++ b/base/compiler/inferencestate.jl
@@ -42,6 +42,7 @@ mutable struct InferenceState
     limited::Bool
     inferred::Bool
     dont_work_on_me::Bool
+    saw_noinline::Bool

     # The place to look up methods while working on this function.
     # In particular, we cache method lookup results for the same function to
@@ -113,7 +114,7 @@ mutable struct InferenceState
             Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
             Vector{InferenceState}(), # callers_in_cycle
             #=parent=#nothing,
-            cached, false, false, false,
+            cached, false, false, false, false,
             CachedMethodTable(method_table(interp)),
             interp)
         result.result = frame
diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl
index 90324b9665175..071d6cce7ab01 100644
--- a/base/compiler/optimize.jl
+++ b/base/compiler/optimize.jl
@@ -21,10 +21,24 @@ function push!(et::EdgeTracker, ci::CodeInstance)
     push!(et, ci.def)
 end

-struct InferenceCaches{T, S}
-    inf_cache::T
+# An mi_cache that overlays some base cache, but also caches
+# temporary results while we're working on a cycle
+struct CycleInferenceCache{S}
+    cycle_mis::IdDict{MethodInstance, InferenceResult}
     mi_cache::S
 end
+CycleInferenceCache(mi_cache) = CycleInferenceCache(IdDict{MethodInstance, InferenceResult}(), mi_cache)
+
+function setindex!(cic::CycleInferenceCache, v::InferenceResult, mi::MethodInstance)
+    cic.cycle_mis[mi] = v
+    return cic
+end
+
+function get(cic::CycleInferenceCache, mi::MethodInstance, @nospecialize(default))
+    result = get(cic.cycle_mis, mi, nothing)
+    result !== nothing && return result
+    return get(cic.mi_cache, mi, default)
+end

 struct InliningState{S <: Union{EdgeTracker, Nothing}, T <: Union{InferenceCaches, Nothing}, V <: Union{Nothing, MethodTableView}}
     params::OptimizationParams
@@ -42,7 +56,9 @@ mutable struct OptimizationState
     sptypes::Vector{Any} # static parameters
     slottypes::Vector{Any}
     const_api::Bool
-    inlining::InliningState
+    params::OptimizationParams
+    et::Union{Nothing, EdgeTracker}
+    mt::Union{Nothing, MethodTableView}
     function OptimizationState(frame::InferenceState, params::OptimizationParams, interp::AbstractInterpreter)
         s_edges = frame.stmt_edges[1]
         if s_edges === nothing
@@ -50,16 +66,12 @@ mutable struct OptimizationState
             frame.stmt_edges[1] = s_edges
         end
         src = frame.src
-        inlining = InliningState(params,
-            EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
-            InferenceCaches(
-                get_inference_cache(interp),
-                WorldView(code_cache(interp), frame.world)),
-            method_table(interp))
         return new(frame.linfo,
                    src, frame.stmt_info, frame.mod, frame.nargs,
                    frame.sptypes, frame.slottypes, false,
-                   inlining)
+                   params,
+                   EdgeTracker(s_edges::Vector{Any}, frame.valid_worlds),
+                   method_table(interp))
     end
     function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
         # prepare src for running optimization passes
@@ -86,16 +98,13 @@ mutable struct OptimizationState
         end

         # Allow using the global MI cache, but don't track edges.
         # This method is mostly used for unit testing the optimizer
-        inlining = InliningState(params,
-            nothing,
-            InferenceCaches(
-                get_inference_cache(interp),
-                WorldView(code_cache(interp), get_world_counter())),
-            method_table(interp))
+
         return new(linfo,
                    src, stmt_info, inmodule, nargs,
                    sptypes_from_meth_instance(linfo), slottypes, false,
-                   inlining)
+                   params,
+                   nothing,
+                   method_table(interp))
     end
 end
@@ -180,10 +189,10 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
 end

 # run the optimization work
-function optimize(opt::OptimizationState, params::OptimizationParams, @nospecialize(result))
+function optimize(opt::OptimizationState, params::OptimizationParams, caches::InferenceCaches, @nospecialize(result))
     def = opt.linfo.def
     nargs = Int(opt.nargs) - 1
-    @timeit "optimizer" ir = run_passes(opt.src, nargs, opt)
+    @timeit "optimizer" ir = run_passes(opt.src, nargs, caches, opt)
     force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)

     # compute inlining and other related optimizations
diff --git a/base/compiler/ssair/domtree.jl b/base/compiler/ssair/domtree.jl
index 1ab2876b769da..2afed390db6c8 100644
--- a/base/compiler/ssair/domtree.jl
+++ b/base/compiler/ssair/domtree.jl
@@ -109,9 +109,14 @@ end

 length(D::DFSTree) = length(D.from_pre)

-function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
+succs(bb::BasicBlock) = bb.succs
+
+function DFS!(D::DFSTree, blocks::Vector; roots = 1)
     copy!(D, DFSTree(length(blocks)))
-    to_visit = Tuple{BBNumber, PreNumber, Bool}[(1, 0, false)]
+    to_visit = Tuple{BBNumber, PreNumber, Bool}[]
+    for root in roots
+        push!(to_visit, (root, 0, false))
+    end
     pre_num = 1
     post_num = 1
     while !isempty(to_visit)
@@ -144,7 +149,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
         to_visit[end] = (current_node_bb, parent_pre, true)

         # Push children to the stack
-        for succ_bb in blocks[current_node_bb].succs
+        for succ_bb in succs(blocks[current_node_bb])
             push!(to_visit, (succ_bb, pre_num, false))
         end

@@ -161,7 +166,7 @@ function DFS!(D::DFSTree, blocks::Vector{BasicBlock})
     return D
 end

-DFS(blocks::Vector{BasicBlock}) = DFS!(DFSTree(0), blocks)
+DFS(blocks::Vector; roots = 1) = DFS!(DFSTree(0), blocks; roots)

 """
 Keeps the per-BB state of the Semi NCA algorithm. In the original formulation,
diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl
index 83205033342d6..bbed8edc218b9 100644
--- a/base/compiler/ssair/driver.jl
+++ b/base/compiler/ssair/driver.jl
@@ -118,14 +118,15 @@ function slot2reg(ir::IRCode, ci::CodeInfo, nargs::Int, sv::OptimizationState)
     return ir
 end

-function run_passes(ci::CodeInfo, nargs::Int, sv::OptimizationState)
+function run_passes(ci::CodeInfo, nargs::Int, caches::InferenceCaches, sv::OptimizationState)
     preserve_coverage = coverage_enabled(sv.mod)
     ir = convert_to_ircode(ci, copy_exprargs(ci.code), preserve_coverage, nargs, sv)
     ir = slot2reg(ir, ci, nargs, sv)
     #@Base.show ("after_construct", ir)
     # TODO: Domsorting can produce an updated domtree - no need to recompute here
     @timeit "compact 1" ir = compact!(ir)
-    @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
+    inlining = InliningState(sv.params, sv.et, caches, sv.mt)
+    @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, inlining, ci.propagate_inbounds)
     #@timeit "verify 2" verify_ir(ir)
     ir = compact!(ir)
     #@Base.show ("before_sroa", ir)
diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl
index 0e95f812e5eb6..02a4be1074e95 100644
--- a/base/compiler/ssair/inlining.jl
+++ b/base/compiler/ssair/inlining.jl
@@ -1419,8 +1419,18 @@ function find_inferred(mi::MethodInstance, atypes::Vector{Any}, caches::InferenceCaches
             return svec(true, quoted(linfo.rettype_const))
         end
         return svec(false, linfo.inferred)
-    else
-        # `linfo` may be `nothing` or an IRCode here
-        return svec(false, linfo)
+    elseif isa(linfo, InferenceResult)
+        let inferred_src = linfo.src
+            if isa(inferred_src, CodeInfo)
+                return svec(false, inferred_src)
+            end
+            if isa(inferred_src, Const) && is_inlineable_constant(inferred_src.val)
+                return svec(true, quoted(inferred_src.val),)
+            end
+        end
+        linfo = nothing
     end
+
+    # `linfo` may be `nothing` or an IRCode here
+    return svec(false, linfo)
 end
diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl
index 04c0edb9a0fde..ddf2f5c813781 100644
--- a/base/compiler/typeinfer.jl
+++ b/base/compiler/typeinfer.jl
@@ -208,6 +208,36 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
     end
 end

+struct ForwardEdges
+    edges::Vector{Int}
+end
+ForwardEdges() = ForwardEdges(Int[])
+succs(f::ForwardEdges) = f.edges
+push!(f::ForwardEdges, i::Int) = push!(f.edges, i)
+
+function postorder_sort_frames(frames)
+    length(frames) == 1 && return frames
+    roots = Int[i for i in 1:length(frames) if frames[i].saw_noinline]
+    # If there are no noinline annoations, just leave the default order
+    isempty(roots) && return frames
+
+    # Number frames
+    numbering = IdDict{Any, Int}(frames[i] => i for i in 1:length(frames))
+
+    # Compute forward edges
+    forward_edges = ForwardEdges[ForwardEdges() for i in 1:length(frames)]
+    for i in 1:length(frames)
+        frame = frames[i]
+        for (edge, _) in frame.cycle_backedges
+            push!(forward_edges[numbering[edge]], i)
+        end
+    end
+
+    # Compute postorder
+    dfs_tree = DFS(forward_edges; roots=roots)
+    return InferenceState[frames[i] for i in dfs_tree.from_post]
+end
+
 function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
     typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
     # with no active ip's, frame is done
@@ -220,19 +250,26 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
     for caller in frames
         finish(caller, interp)
     end
+    # We postorder sort frames rooted on any frames marked noinline (if any).
+    # This makes sure that the inliner has the maximum opportunity to inline.
+    frames = postorder_sort_frames(frames)
     # collect results for the new expanded frame
     results = Tuple{InferenceResult, Bool}[ ( frames[i].result,
         frames[i].cached || frames[i].parent !== nothing ) for i in 1:length(frames) ]
     # empty!(frames)
     valid_worlds = frame.valid_worlds
     cached = frame.cached
+    caches = InferenceCaches(interp)
+    cycle_cache = CycleInferenceCache(caches.mi_cache)
+    caches = InferenceCaches(caches.inf_cache, cycle_cache)
     if cached || frame.parent !== nothing
         for (caller, doopt) in results
             opt = caller.src
             if opt isa OptimizationState
                 run_optimizer = doopt && may_optimize(interp)
                 if run_optimizer
-                    optimize(opt, OptimizationParams(interp), caller.result)
+                    cycle_cache[opt.linfo] = caller
+                    optimize(opt, OptimizationParams(interp), caches, caller.result)
                     finish(opt.src, interp)
                     # finish updating the result struct
                     validate_code_in_debug_mode(opt.linfo, opt.src, "optimized")
@@ -251,7 +288,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
                 end
                 # As a hack the et reuses frame_edges[1] to push any optimization
                 # edges into, so we don't need to handle them specially here
-                valid_worlds = intersect(valid_worlds, opt.inlining.et.valid_worlds[])
+                valid_worlds = intersect(valid_worlds, opt.et.valid_worlds[])
             end
         end
     end
@@ -768,7 +805,7 @@ function typeinf_code(interp::AbstractInterpreter, method::Method, @nospecialize
     if typeinf(interp, frame) && run_optimizer
         opt_params = OptimizationParams(interp)
         opt = OptimizationState(frame, opt_params, interp)
-        optimize(opt, opt_params, result.result)
+        optimize(opt, opt_params, InferenceCaches(interp), result.result)
         opt.src.inferred = true
     end
     ccall(:jl_typeinf_end, Cvoid, ())
diff --git a/base/compiler/types.jl b/base/compiler/types.jl
index 3ca6cff20ccd6..446df5719be4a 100644
--- a/base/compiler/types.jl
+++ b/base/compiler/types.jl
@@ -134,6 +134,11 @@ struct InferenceParams
     end
 end

+struct InferenceCaches{T, S}
+    inf_cache::T
+    mi_cache::S
+end
+
 """
     NativeInterpreter

@@ -187,6 +192,11 @@ get_inference_cache(ni::NativeInterpreter) = ni.cache

 code_cache(ni::NativeInterpreter) = WorldView(GLOBAL_CI_CACHE, ni.world)

+InferenceCaches(ni::NativeInterpreter) =
+    InferenceCaches(
+        get_inference_cache(ni),
+        WorldView(code_cache(ni), ni.world))
+
 """
     lock_mi_inference(ni::NativeInterpreter, mi::MethodInstance)