diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index ab93375db4d0e..f23f1029bb5ab 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig))) + if i != length(cases) || (!fully_covered || (!params.trust_inference)) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector spec = item.spec::ResolvedInliningSpec sparam_vals = item.mi.sparam_vals def = item.mi.def::Method - inline_cfg = spec.ir.cfg linetable_offset::Int32 = length(linetable) # Append the linetable of the inlined function to our line table inlined_at = Int(compact.result[idx][:line]) @@ -471,8 +470,9 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, join_bb = bbs[end] pn = PhiNode() local bb = compact.active_result_bb - @assert length(bbs) >= length(cases) - for i in 1:length(cases) + ncases = length(cases) + @assert length(bbs) >= ncases + for i = 1:ncases ithcase = cases[i] mtype = ithcase.sig::DataType # checked within `handle_cases!` case = ithcase.item @@ -480,8 +480,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true nparams = fieldcount(atype) @assert nparams == fieldcount(mtype) - if i != length(cases) || !fully_covered || - (!params.trust_inference && isdispatchtuple(cases[i].sig)) + if i != ncases || !fully_covered || !params.trust_inference for i = 1:nparams a, m = fieldtype(atype, i), fieldtype(mtype, i) # If this is always true, we don't need to check for it @@ -538,7 +537,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the fall through block, decide what to do if fully_covered - if !params.trust_inference && isdispatchtuple(cases[end].sig) + if !params.trust_inference e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -561,7 +560,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect state = CFGInliningState(ir) for (idx, item) in todo if isa(item, UnionSplit) - cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params) + cfg_inline_unionsplit!(ir, idx, item, state, params) else item = item::InliningTodo spec = item.spec::ResolvedInliningSpec @@ -1175,12 +1174,8 @@ function analyze_single_call!( sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}}) argtypes = sig.argtypes cases = InliningCase[] - local only_method = nothing # keep track of whether there is one matching method - local meth::MethodLookupResult + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false - local revisit_idx = nothing - for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1191,66 +1186,20 @@ function analyze_single_call!( # No applicable methods; try next union split handled_all_cases = false continue - else - if length(meth) == 1 && only_method !== false - if only_method === nothing - only_method = meth[1].method - elseif only_method !== meth[1].method - only_method = false - end - else - only_method = false - end end - for (j, match) in enumerate(meth) - any_covers_full |= match.fully_covers - if !isdispatchtuple(match.spec_types) - if !match.fully_covers - handled_all_cases = false - continue - end - if revisit_idx === nothing - revisit_idx = (i, j) - else - handled_all_cases = false - revisit_idx = nothing - end - else - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) - end + for match in meth + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) + any_fully_covered |= match.fully_covers end end - atype = argtypes_to_type(argtypes) - if handled_all_cases && revisit_idx !== nothing - # If there's only one case that's not a dispatchtuple, we can - # still unionsplit by visiting all the other cases first. - # This is useful for code like: - # foo(x::Int) = 1 - # foo(@nospecialize(x::Any)) = 2 - # where we where only a small number of specific dispatchable - # cases are split off from an ::Any typed fallback. - (i, j) = revisit_idx - match = infos[i].results[j] - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) - elseif length(cases) == 0 && only_method isa Method - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple. - # -- But don't try it if we already tried to handle the match in the revisit_idx - # case, because that'll (necessarily) be the same method. - if length(infos) > 1 - (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), - atype, only_method.sig)::SimpleVector - match = MethodMatch(metharg, methsp::SimpleVector, only_method, true) - else - @assert length(meth) == 1 - match = meth[1] - end - handle_match!(match, argtypes, flag, state, cases, true) || return nothing - any_covers_full = handled_all_cases = match.fully_covers + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1261,8 +1210,8 @@ function handle_const_call!( (; call, results) = cinfo infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches cases = InliningCase[] + local any_fully_covered = false local handled_all_cases = true - local any_covers_full = false local j = 0 for i in 1:length(infos) meth = infos[i].results @@ -1278,32 +1227,26 @@ function handle_const_call!( for match in meth j += 1 result = results[j] - any_covers_full |= match.fully_covers + any_fully_covered |= match.fully_covers if isa(result, ConstResult) case = const_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, InferenceResult) - handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) + handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true) else @assert result === nothing - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) end end end - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple - atype = argtypes_to_type(argtypes) - if length(cases) == 0 - length(results) == 1 || return nothing - result = results[1] - isa(result, InferenceResult) || return nothing - handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing - spec_types = cases[1].sig - any_covers_full = handled_all_cases = atype <: spec_types + if !handled_all_cases + # if we've not seen all candidates, union split is valid only for dispatch tuples + filter!(case::InliningCase->isdispatchtuple(case.sig), cases) end - handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) + handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases, + handled_all_cases & any_fully_covered, todo, state.params) end function handle_match!( @@ -1311,9 +1254,12 @@ function handle_match!( cases::Vector{InliningCase}, allow_abstract::Bool = false) spec_types = match.spec_types allow_abstract || isdispatchtuple(spec_types) || return false + # we may see duplicated dispatch signatures here when a signature gets widened + # during abstract interpretation: for the purpose of inlining, we can just skip + # processing this dispatch candidate + _any(case->case.sig === spec_types, cases) && return true item = analyze_method!(match, argtypes, flag, state) item === nothing && return false - _any(case->case.sig === spec_types, cases) && return true push!(cases, InliningCase(spec_types, item)) return true end @@ -1349,7 +1295,24 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), handle_single_case!(ir, idx, stmt, cases[1].item, todo, params) elseif length(cases) > 0 isa(atype, DataType) || return nothing - all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing + # `ir_inline_unionsplit!` is going to generate `isa` checks corresponding to the + # signatures of union-split dispatch candidates in order to simulate the dispatch + # semantics, and inline their bodies into each `isa`-conditional block -- and since + # we may deal with abstract union-split callsites here, these dispatch candidates + # need to be sorted in order of their signature specificity. + # Fortunately, ml_matches already sorted them in that way, so we can just process + # them in order, as far as we haven't changed their order somewhere up to this point. + ncases = length(cases) + for i = 1:ncases + sigᵢ = cases[i].sig + isa(sigᵢ, DataType) || return nothing + for j = i+1:ncases + sigⱼ = cases[j].sig + # since we already bail out from ambiguous case, we can use `morespecific` as + # a strict total order of specificity (in a case when they don't have a type intersection) + @assert !hasintersect(sigᵢ, sigⱼ) || morespecific(sigᵢ, sigⱼ) "invalid order of dispatch candidate" + end + end push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end return nothing @@ -1445,7 +1408,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState) analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo) end - todo + + return todo end function linear_inline_eligible(ir::IRCode) diff --git a/base/sort.jl b/base/sort.jl index d26e9a4b09332..981eea35d96ab 100644 --- a/base/sort.jl +++ b/base/sort.jl @@ -5,7 +5,7 @@ module Sort import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base.Order -using .Base: copymutable, LinearIndices, length, (:), +using .Base: copymutable, LinearIndices, length, (:), iterate, eachindex, axes, first, last, similar, zip, OrdinalRange, AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline, AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !, diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index fa4425893767c..88e8ca0e62a41 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -810,6 +810,76 @@ let @test invoke(Any[10]) === false end +# test union-split, non-dispatchtuple callsite inlining + +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Any) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(x) + end + @test count(isinvoke(:abstract_unionsplit), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Type) = Base.inferencebarrier(:Any) +@constprop :none @noinline abstract_unionsplit_fallback(@nospecialize x::Number) = Base.inferencebarrier(:Number) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(x) + end + @test count(isinvoke(:abstract_unionsplit_fallback), src.code) == 2 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Any) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Type) = (c && println("erase me"); typeof(x)) +@constprop :aggressive @inline abstract_unionsplit_fallback(c, @nospecialize x::Number) = (c && println("erase me"); typeof(x)) +let src = code_typed1((Any,)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit_fallback)), src.code) == 1 # fallback dispatch +end +let src = code_typed1((Union{Type,Number},)) do x + abstract_unionsplit_fallback(false, x) + end + @test count(iscall((src, typeof)), src.code) == 2 + @test count(isinvoke(:println), src.code) == 0 + @test count(iscall((src, println)), src.code) == 0 + @test count(iscall((src, abstract_unionsplit)), src.code) == 0 # no fallback dispatch +end + # issue 43104 @inline isGoodType(@nospecialize x::Type) = @@ -1097,11 +1167,11 @@ end global x44200::Int = 0 function f44200() - global x = 0 - while x < 10 - x += 1 + global x44200 = 0 + while x44200 < 10 + x44200 += 1 end - x + x44200 end let src = code_typed1(f44200) @test count(x -> isa(x, Core.PiNode), src.code) == 0