Skip to content

Commit

Permalink
inference: enable constant propagation for invoked calls, fixes #41024
Browse files Browse the repository at this point in the history
 (#41383)

* inference: enable constant propagation for `invoke`d calls, fixes #41024

Especially useful for defining mixins with typed interface fields, e.g.
```julia
abstract type AbstractInterface end # mixin, which expects common field `x::Int`
function Base.getproperty(x::AbstractInterface, sym::Symbol)
    if sym === :x
        return getfield(x, sym)::Int # inferred field
    else
        return getfield(x, sym)      # fallback
    end
end

abstract type AbstractInterfaceExtended <: AbstractInterface end # extended mixin, which expects additional common field `y::Rational{Int}`
function Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol)
    if sym === :y
        return getfield(x, sym)::Rational{Int}
    end
    return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol)
end
```

As a bonus, inliner is able to use `InferenceResult` as a fast inlining
pass for constant-prop'ed `invoke`s

* improve compile-time latency

* Update base/compiler/abstractinterpretation.jl

* Update base/compiler/abstractinterpretation.jl
  • Loading branch information
aviatesk authored Jun 30, 2021
1 parent 7c566b1 commit bc6da93
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 17 deletions.
51 changes: 37 additions & 14 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,8 @@ function abstract_call_unionall(argtypes::Vector{Any})
end

function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
ft = widenconst(argtype_by_index(argtypes, 2))
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, false)
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
types === Bottom && return CallMeta(Bottom, false)
Expand All @@ -1149,15 +1150,30 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
result = findsup(types, method_table(interp))
if result === nothing
return CallMeta(Any, false)
end
result === nothing && return CallMeta(Any, false)
method, valid_worlds = result
update_valid_age!(sv, valid_worlds)
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
rt, edge = typeinf_edge(interp, method, ti, env, sv)
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
return CallMeta(rt, InvokeCallInfo(MethodMatch(ti, env, method, argtype <: method.sig)))
match = MethodMatch(ti, env, method, argtype <: method.sig)
# try constant propagation with manual inlinings of some of the heuristics
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
argtypes′ = argtypes[4:end]
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
pushfirst!(argtypes′, ft)
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
# for i in 1:length(argtypes′)
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
if const_rt !== rt && const_rt rt
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
else
return CallMeta(rt, InvokeCallInfo(match, nothing))
end
end

# call where the function is known exactly
Expand Down Expand Up @@ -1291,17 +1307,12 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
#print("call ", e.args[1], argtypes, "\n\n")
ft = argtypes[1]
if isa(ft, Const)
f = ft.val
elseif isconstType(ft)
f = ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialOpaque)
f = argtype_to_function(ft)
if isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
else
elseif f === nothing
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{}
Expand All @@ -1313,6 +1324,18 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
end

function argtype_to_function(@nospecialize(ft))
if isa(ft, Const)
return ft.val
elseif isconstType(ft)
return ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
return ft.instance
else
return nothing
end
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
isref = false
if T === Bottom
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1049,12 +1049,12 @@ is_builtin(s::Signature) =
isa(s.f, Builtin) ||
s.ft Builtin

function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
state::InliningState, todo::Vector{Pair{Int, Any}})
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]

if !info.match.fully_covers
if !match.fully_covers
# TODO: We could union split out the signature check and continue on
return nothing
end
Expand All @@ -1064,7 +1064,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
atypes = atypes[4:end]
pushfirst!(atypes, atype0)

result = analyze_method!(info.match, atypes, state, calltype)
if isa(result, InferenceResult)
item = InliningTodo(result, atypes, calltype)
validate_sparams(item.mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: item.mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state))
handle_single_case!(ir, stmt, idx, item, true, todo)
return nothing
end
end

result = analyze_method!(match, atypes, state, calltype)
handle_single_case!(ir, stmt, idx, result, true, todo)
return nothing
end
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ method being processed.
"""
struct InvokeCallInfo
match::MethodMatch
result::Union{Nothing,InferenceResult}
end

struct OpaqueClosureCallInfo
Expand Down
40 changes: 40 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3355,3 +3355,43 @@ let
Expr(:opaque_closure_method, nothing, 2, LineNumberNode(0, nothing), ci)))(true, 1.0)
@test Base.return_types(oc, Tuple{}) == Any[Float64]
end

@testset "constant prop' on `invoke` calls" begin
m = Module()

# simple cases
@eval m begin
f(a::Any, sym::Bool) = sym ? Any : :any
f(a::Number, sym::Bool) = sym ? Number : :number
end
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Any, true::Bool)
end) == Any[Type{Any}]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Number, true::Bool)
end) == Any[Type{Number}]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Any, false::Bool)
end) == Any[Symbol]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Number, false::Bool)
end) == Any[Symbol]

# https://github.com/JuliaLang/julia/issues/41024
@eval m begin
# mixin, which expects common field `x::Int`
abstract type AbstractInterface end
Base.getproperty(x::AbstractInterface, sym::Symbol) =
sym === :x ? getfield(x, sym)::Int :
return getfield(x, sym) # fallback

# extended mixin, which expects additional field `y::Rational{Int}`
abstract type AbstractInterfaceExtended <: AbstractInterface end
Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol) =
sym === :y ? getfield(x, sym)::Rational{Int} :
return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol)
end
@test (@eval m Base.return_types((AbstractInterfaceExtended,)) do x
x.x
end) == Any[Int]
end

0 comments on commit bc6da93

Please sign in to comment.