diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index 17141cf8eaec7..09c550b09029d 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -57,7 +57,7 @@ struct InvokeCase end struct InliningCase - sig # ::Type + sig # Type item # Union{InliningTodo, MethodInstance, ConstantCase} function InliningCase(@nospecialize(sig), @nospecialize(item)) @assert isa(item, Union{InliningTodo, InvokeCase, ConstantCase}) "invalid inlining item" @@ -67,10 +67,10 @@ end struct UnionSplit fully_covered::Bool - atype # ::Type + atype::DataType cases::Vector{InliningCase} bbs::Vector{Int} - UnionSplit(fully_covered::Bool, atype, cases::Vector{InliningCase}) = + UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) = new(fully_covered, atype, cases, Int[]) end @@ -474,12 +474,11 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, @assert length(bbs) >= length(cases) for i in 1:length(cases) ithcase = cases[i] - metharg = ithcase.sig + metharg = ithcase.sig::DataType # checked within `handle_cases!` case = ithcase.item next_cond_bb = bbs[i] - @assert isa(metharg, DataType) cond = true - aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector + aparams, mparams = atype.parameters, metharg.parameters @assert length(aparams) == length(mparams) if i != length(cases) || !fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)) @@ -1222,7 +1221,6 @@ function analyze_single_call!( end end - atype = argtypes_to_type(argtypes) if handled_all_cases && revisit_idx !== nothing # If there's only one case that's not a dispatchtuple, we can @@ -1234,7 +1232,7 @@ function analyze_single_call!( # cases are split off from an ::Any typed fallback. (i, j) = revisit_idx match = infos[i].results[j] - handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true) elseif length(cases) == 0 && only_method isa Method # if the signature is fully covered and there is only one applicable method, # we can try to inline it even if the signature is not a dispatch tuple. @@ -1248,9 +1246,7 @@ function analyze_single_call!( @assert length(meth) == 1 match = meth[1] end - item = analyze_method!(match, argtypes, flag, state) - item === nothing && return nothing - push!(cases, InliningCase(match.spec_types, item)) + handle_match!(match, argtypes, flag, state, cases, true) || return nothing any_covers_full = handled_all_cases = match.fully_covers end @@ -1290,7 +1286,7 @@ function handle_const_call!( handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) else @assert result === nothing - handled_all_cases &= isdispatchtuple(match.spec_types) && handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) end end end @@ -1298,13 +1294,13 @@ function handle_const_call!( # if the signature is fully covered and there is only one applicable method, # we can try to inline it even if the signature is not a dispatch tuple atype = argtypes_to_type(argtypes) - if length(cases) == 0 && length(results) == 1 && isa(results[1], InferenceResult) - (; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes) - state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) - validate_sparams(mi.sparam_vals) || return nothing - item === nothing && return nothing - push!(cases, InliningCase(mi.specTypes, item)) - any_covers_full = handled_all_cases = atype <: mi.specTypes + if length(cases) == 0 + length(results) == 1 || return nothing + result = results[1] + isa(result, InferenceResult) || return nothing + handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing + spec_types = cases[1].sig + any_covers_full = handled_all_cases = atype <: spec_types end handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) @@ -1312,8 +1308,9 @@ end function handle_match!( match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}) + cases::Vector{InliningCase}, allow_abstract::Bool = false) spec_types = match.spec_types + allow_abstract || isdispatchtuple(spec_types) || return false item = analyze_method!(match, argtypes, flag, state) item === nothing && return false _any(case->case.sig === spec_types, cases) && return true @@ -1323,10 +1320,10 @@ end function handle_inf_result!( result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState, - cases::Vector{InliningCase}) + cases::Vector{InliningCase}, allow_abstract::Bool = false) (; mi) = item = InliningTodo(result, argtypes) spec_types = mi.specTypes - isdispatchtuple(spec_types) || return false + allow_abstract || isdispatchtuple(spec_types) || return false validate_sparams(mi.sparam_vals) || return false state.mi_cache !== nothing && (item = resolve_todo(item, state, flag)) item === nothing && return false @@ -1351,6 +1348,8 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype), if fully_covered && length(cases) == 1 handle_single_case!(ir, idx, stmt, cases[1].item, todo, params) elseif length(cases) > 0 + isa(atype, DataType) || return nothing + all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing push!(todo, idx=>UnionSplit(fully_covered, atype, cases)) end return nothing