Skip to content

Commit

Permalink
inlining: avoid source deserialization by using volatile inference re…
Browse files Browse the repository at this point in the history
…sult (#51934)

Currently the inlining algorithm is allowed to use inferred source of
const-prop'ed call that is always locally available (since const-prop'
result isn't cached globally). For non const-prop'ed and globally cached
calls, however, it undergoes a more expensive process, making a
round-trip through serialized inferred source.

We can improve efficiency by bypassing the serialization round-trip for
newly-inferred and globally-cached frames. As these frames are never
cached locally, they can be viewed as volatile. This means we can use
their source destructively while inline-expanding them.

The benchmark results show that this optimization achieves 2-4% 
allocation reduction and about 5% speed up in the real-world-ish 
compilation targets (`allinference`).

Note that it would be more efficient to propagate `IRCode` object
directly and skip inflation from `CodeInfo` to `IRCode` as experimented
in #47137, but currently the round-trip through
`CodeInfo`-representation is necessary because it often leads to better
CFG simplification while `cfg_simplify!` being expensive (xref: #51960).
  • Loading branch information
aviatesk authored Nov 6, 2023
1 parent b723f41 commit fae6b78
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 91 deletions.
26 changes: 14 additions & 12 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
splitsigs = switchtupleunion(sig)
for sig_n in splitsigs
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, volatile_inf_result) = result
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, this_arginfo, si, match, sv)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
if const_call_result.rt ₚ rt
rt = const_call_result.rt
Expand All @@ -90,7 +90,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_rt = widenwrappedconditional(this_rt)
else
result = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, volatile_inf_result) = result
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
# try constant propagation with argtypes for this match
Expand All @@ -99,7 +99,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
this_arginfo = ArgInfo(fargs, this_argtypes)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, this_arginfo, si, match, sv)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
this_const_conditional = ignorelimited(const_call_result.rt)
this_const_rt = widenwrappedconditional(const_call_result.rt)
Expand Down Expand Up @@ -621,7 +621,7 @@ function abstract_call_method(interp::AbstractInterpreter,
sparams = recomputed[2]::SimpleVector
end

(; rt, edge, effects) = typeinf_edge(interp, method, sig, sparams, sv)
(; rt, edge, effects, volatile_inf_result) = typeinf_edge(interp, method, sig, sparams, sv)

if edge === nothing
edgecycle = edgelimited = true
Expand All @@ -645,7 +645,7 @@ function abstract_call_method(interp::AbstractInterpreter,
end
end

return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects, volatile_inf_result)
end

function edge_matches_sv(interp::AbstractInterpreter, frame::AbsIntState,
Expand Down Expand Up @@ -748,12 +748,14 @@ struct MethodCallResult
edgelimited::Bool
edge::Union{Nothing,MethodInstance}
effects::Effects
volatile_inf_result::Union{Nothing,VolatileInferenceResult}
function MethodCallResult(@nospecialize(rt),
edgecycle::Bool,
edgelimited::Bool,
edge::Union{Nothing,MethodInstance},
effects::Effects)
return new(rt, edgecycle, edgelimited, edge, effects)
effects::Effects,
volatile_inf_result::Union{Nothing,VolatileInferenceResult}=nothing)
return new(rt, edgecycle, edgelimited, edge, effects, volatile_inf_result)
end
end

Expand Down Expand Up @@ -1945,7 +1947,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
tienv = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
ti = tienv[1]; env = tienv[2]::SimpleVector
result = abstract_call_method(interp, method, ti, env, false, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, volatile_inf_result) = result
match = MethodMatch(ti, env, method, argtype <: method.sig)
res = nothing
sig = match.spec_types
Expand All @@ -1962,7 +1964,7 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
invokecall = InvokeCall(types, lookupsig)
const_call_result = abstract_call_method_with_const_args(interp,
result, f, arginfo, si, match, sv, invokecall)
const_result = nothing
const_result = volatile_inf_result
if const_call_result !== nothing
if (𝕃ₚ, const_call_result.rt, rt)
(; rt, effects, const_result, edge) = const_call_result
Expand Down Expand Up @@ -2091,13 +2093,13 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
closure::PartialOpaque, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, check::Bool=true)
sig = argtypes_to_type(arginfo.argtypes)
result = abstract_call_method(interp, closure.source::Method, sig, Core.svec(), false, si, sv)
(; rt, edge, effects) = result
(; rt, edge, effects, volatile_inf_result) = result
tt = closure.typ
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
𝕃ₚ = ipo_lattice(interp)
= (𝕃ₚ)
const_result = nothing
const_result = volatile_inf_result
if !result.edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result,
nothing, arginfo, si, match, sv)
Expand Down
9 changes: 6 additions & 3 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}

const CACHE_MODE_NULL = 0x00
const CACHE_MODE_GLOBAL = 0x01 << 0
const CACHE_MODE_LOCAL = 0x01 << 1
const CACHE_MODE_NULL = 0x00 # not cached, without optimization
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization allowed
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization allowed

mutable struct InferenceState
#= information about this method instance =#
Expand Down Expand Up @@ -467,6 +468,8 @@ function convert_cache_mode(cache_mode::Symbol)
return CACHE_MODE_GLOBAL
elseif cache_mode === :local
return CACHE_MODE_LOCAL
elseif cache_mode === :volatile
return CACHE_MODE_VOLATILE
elseif cache_mode === :no
return CACHE_MODE_NULL
end
Expand Down
16 changes: 1 addition & 15 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,25 +75,11 @@ is_source_inferred(@nospecialize src::MaybeCompressed) =

function inlining_policy(interp::AbstractInterpreter,
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt32, mi::MethodInstance,
argtypes::Vector{Any})
_::Vector{Any})
if isa(src, MaybeCompressed)
is_source_inferred(src) || return nothing
src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src)
return src_inlineable ? src : nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
# if this statement is forced to be inlined, make an additional effort to find the
# inferred source in the local cache
# we still won't find a source for recursive call because the "single-level" inlining
# seems to be more trouble and complex than it's worth
inf_result = cache_lookup(optimizer_lattice(interp), mi, argtypes, get_inference_cache(interp))
inf_result === nothing && return nothing
src = inf_result.src
if isa(src, CodeInfo)
src_inferred = is_source_inferred(src)
return src_inferred ? src : nothing
else
return nothing
end
elseif isa(src, IRCode)
return src
elseif isa(src, SemiConcreteResult)
Expand Down
102 changes: 62 additions & 40 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -842,51 +842,57 @@ function compileable_specialization(match::MethodMatch, effects::Effects,
return compileable_specialization(mi, effects, et, info; compilesig_invokes)
end

struct CachedResult
struct InferredResult
src::Any
effects::Effects
CachedResult(@nospecialize(src), effects::Effects) = new(src, effects)
InferredResult(@nospecialize(src), effects::Effects) = new(src, effects)
end
@inline function get_cached_result(state::InliningState, mi::MethodInstance)
code = get(code_cache(state), mi, nothing)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
return ConstantCase(quoted(code.rettype_const))
else
src = @atomic :monotonic code.inferred
end
src = @atomic :monotonic code.inferred
effects = decode_effects(code.ipo_purity_bits)
return CachedResult(src, effects)
return InferredResult(src, effects)
end
return CachedResult(nothing, Effects())
return InferredResult(nothing, Effects())
end
@inline function get_local_result(inf_result::InferenceResult)
effects = inf_result.ipo_effects
if is_foldable_nothrow(effects)
res = inf_result.result
if isa(res, Const) && is_inlineable_constant(res.val)
# use constant calling convention
return ConstantCase(quoted(res.val))
end
end
return InferredResult(inf_result.src, effects)
end

# the general resolver for usual and const-prop'ed calls
function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceResult},
argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState; invokesig::Union{Nothing,Vector{Any}}=nothing)
function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult},
argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
invokesig::Union{Nothing,Vector{Any}}=nothing)
et = InliningEdgeTracker(state, invokesig)

preserve_local_sources = true
if isa(result, InferenceResult)
src = result.src
effects = result.ipo_effects
if is_foldable_nothrow(effects)
res = result.result
if isa(res, Const) && is_inlineable_constant(res.val)
# use constant calling convention
add_inlining_backedge!(et, mi)
return ConstantCase(quoted(res.val))
end
end
inferred_result = get_local_result(result)
elseif isa(result, VolatileInferenceResult)
inferred_result = get_local_result(result.inf_result)
# volatile inference result can be inlined destructively
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
else
cached_result = get_cached_result(state, mi)
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
end
(; src, effects) = cached_result
inferred_result = get_cached_result(state, mi)
end
if inferred_result isa ConstantCase
add_inlining_backedge!(et, mi)
return inferred_result
end
(; src, effects) = inferred_result

# the duplicated check might have been done already within `analyze_method!`, but still
# we need it here too since we may come here directly using a constant-prop' result
Expand All @@ -900,7 +906,8 @@ function resolve_todo(mi::MethodInstance, result::Union{MethodMatch,InferenceRes
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
return InliningTodo(mi, retrieve_ir_for_inlining(mi, src), effects)
ir = retrieve_ir_for_inlining(mi, src, preserve_local_sources)
return InliningTodo(mi, ir, effects)
end

# the special resolver for :invoke-d call
Expand Down Expand Up @@ -944,7 +951,8 @@ end

function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
@nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing)
allow_typevars::Bool, invokesig::Union{Nothing,Vector{Any}}=nothing,
volatile_inf_result::Union{Nothing,VolatileInferenceResult}=nothing)
method = match.method
spec_types = match.spec_types

Expand Down Expand Up @@ -973,15 +981,25 @@ function analyze_method!(match::MethodMatch, argtypes::Vector{Any},
# Get the specialization for this method signature
# (later we will decide what to do with it)
mi = specialize_method(match)
return resolve_todo(mi, match, argtypes, info, flag, state; invokesig)
return resolve_todo(mi, volatile_inf_result, argtypes, info, flag, state; invokesig)
end

function retrieve_ir_for_inlining(mi::MethodInstance, src::String)
function retrieve_ir_for_inlining(mi::MethodInstance, src::String, ::Bool=true)
src = _uncompressed_ir(mi.def, src)
return inflate_ir!(src, mi)
end
retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo) = inflate_ir(src, mi)
retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode) = copy(ir)
function retrieve_ir_for_inlining(mi::MethodInstance, src::CodeInfo, preserve_local_sources::Bool=true)
if preserve_local_sources
src = copy(src)
end
return inflate_ir!(src, mi)
end
function retrieve_ir_for_inlining(::MethodInstance, ir::IRCode, preserve_local_sources::Bool=true)
if preserve_local_sources
ir = copy(ir)
end
return ir
end

function flags_for_effects(effects::Effects)
flags::UInt32 = 0
Expand Down Expand Up @@ -1203,7 +1221,8 @@ function handle_invoke_call!(todo::Vector{Pair{Int,Any}},
return nothing
end
end
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig)
volatile_inf_result = result isa VolatileInferenceResult ? result : nothing
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars=false, invokesig, volatile_inf_result)
end
end
handle_single_case!(todo, ir, idx, stmt, item, true)
Expand Down Expand Up @@ -1343,8 +1362,8 @@ function handle_any_const_result!(cases::Vector{InliningCase},
if isa(result, ConstPropResult)
return handle_const_prop_result!(cases, result, argtypes, info, flag, state; allow_abstract, allow_typevars)
else
@assert result === nothing
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars)
@assert result === nothing || result isa VolatileInferenceResult
return handle_match!(cases, match, argtypes, info, flag, state; allow_abstract, allow_typevars, volatile_inf_result = result)
end
end

Expand Down Expand Up @@ -1475,14 +1494,14 @@ end
function handle_match!(cases::Vector{InliningCase},
match::MethodMatch, argtypes::Vector{Any}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState;
allow_abstract::Bool, allow_typevars::Bool)
allow_abstract::Bool, allow_typevars::Bool, volatile_inf_result::Union{Nothing,VolatileInferenceResult})
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# We may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate (unless unmatched type parameters are present)
!allow_typevars && any(case::InliningCase->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars)
item = analyze_method!(match, argtypes, info, flag, state; allow_typevars, volatile_inf_result)
item === nothing && return false
push!(cases, InliningCase(spec_types, item))
return true
Expand Down Expand Up @@ -1512,7 +1531,9 @@ function semiconcrete_result_item(result::SemiConcreteResult,
return compileable_specialization(mi, result.effects, et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
else
return InliningTodo(mi, retrieve_ir_for_inlining(mi, result.ir), result.effects)
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
ir = retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
return InliningTodo(mi, ir, result.effects)
end
end

Expand Down Expand Up @@ -1587,7 +1608,9 @@ function handle_opaque_closure_call!(todo::Vector{Pair{Int,Any}},
if isa(result, SemiConcreteResult)
item = semiconcrete_result_item(result, info, flag, state)
else
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false)
@assert result === nothing || result isa VolatileInferenceResult
volatile_inf_result = result
item = analyze_method!(info.match, sig.argtypes, info, flag, state; allow_typevars=false, volatile_inf_result)
end
end
handle_single_case!(todo, ir, idx, stmt, item)
Expand All @@ -1612,8 +1635,7 @@ function handle_modifyfield!_call!(ir::IRCode, idx::Int, stmt::Expr, info::Modif
end

function handle_finalizer_call!(ir::IRCode, idx::Int, stmt::Expr, info::FinalizerInfo,
state::InliningState)

state::InliningState)
# Finalizers don't return values, so if their execution is not observable,
# we can just not register them
if is_removable_if_unused(info.effects)
Expand Down
8 changes: 8 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ struct SemiConcreteResult <: ConstResult
effects::Effects
end

# XXX Technically this does not represent a result of constant inference, but rather that of
# regular edge inference. It might be more appropriate to rename `ConstResult` and
# `ConstCallInfo` to better reflect the fact that they represent either of local or
# volatile inference result.
struct VolatileInferenceResult <: ConstResult
inf_result::InferenceResult
end

"""
info::ConstCallInfo <: CallInfo
Expand Down
Loading

2 comments on commit fae6b78

@aviatesk
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nanosoldier runbenchmarks("inference", vs="@af0bd56f83f305ca941b3fe28acb8b2babcd6d54")

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - no performance regressions were detected. A full report can be found here.

Please sign in to comment.