From ab1a6d65ba1caf393d5cc1628f2449483b7ca625 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Tue, 11 Jul 2023 14:18:10 +0900 Subject: [PATCH] SROA: generalize `unswitchtupleunion` optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit improves SROA pass by extending the `unswitchtupleunion` optimization to handle the general parametric types, e.g.: ```julia julia> struct A{T} x::T end; julia> function foo(a1, a2, c) t = c ? A(a1) : A(a2) return getfield(t, :x) end; julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA")) ``` > Before ``` 2 1 ─ goto #3 if not _4 │ 2 ─ %2 = %new(A{Int64}, _2)::A{Int64} │╻ A └── goto #4 │ 3 ─ %4 = %new(A{Float64}, _3)::A{Float64} │╻ A 4 ┄ %5 = φ (#2 => %2, #3 => %4)::Union{A{Float64}, A{Int64}} │ 3 │ %6 = Main.getfield(%5, :x)::Union{Float64, Int64} │ └── return %6 │ => Union{Float64, Int64} ``` > After ``` julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA")) 2 1 ─ goto #3 if not _4 │ 2 ─ nothing::A{Int64} │╻ A └── goto #4 │ 3 ─ nothing::A{Float64} │╻ A 4 ┄ %8 = φ (#2 => _2, #3 => _3)::Union{Float64, Int64} │ │ nothing::Union{A{Float64}, A{Int64}} 3 │ %6 = %8::Union{Float64, Int64} │ └── return %6 │ => Union{Float64, Int64} ``` --- base/compiler/ssair/passes.jl | 4 ++-- base/compiler/typeutils.jl | 45 ++++++++++++++++++++--------------- test/compiler/irpasses.jl | 31 ++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 21 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 6e125d72d5e41..8f20ac28e3606 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -1107,8 +1107,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) end struct_typ = widenconst(argextype(val, compact)) struct_typ_unwrapped = unwrap_unionall(struct_typ) - if isa(struct_typ, Union) && struct_typ <: Tuple - struct_typ_unwrapped = unswitchtupleunion(struct_typ_unwrapped) + if isa(struct_typ, Union) + struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped) end if isa(struct_typ_unwrapped, Union) && is_isdefined lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index f1794bb83c375..7383ec2a440bf 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -317,33 +317,40 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size return depth end -# convert a Union of Tuple types to a Tuple of Unions -unswitchtupleunion(u::Union) = unswitchtypeunion(u, Tuple.name) - +# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing) ts = uniontypes(u) n = -1 for t in ts - if t isa DataType - if typename === nothing - typename = t.name - elseif typename !== t.name - return u - end - if length(t.parameters) != 0 && !isvarargtype(t.parameters[end]) - if n == -1 - n = length(t.parameters) - elseif n != length(t.parameters) - return u - end - end - else + t isa DataType || return u + if typename === nothing + typename = t.name + elseif typename !== t.name + return u + end + params = t.parameters + np = length(params) + if np == 0 || isvarargtype(params[end]) + return u + end + if n == -1 + n = np + elseif n ≠ np return u end end Head = (typename::Core.TypeName).wrapper - unionparams = Any[ Union{Any[(t::DataType).parameters[i] for t in ts]...} for i in 1:n ] - return Head{unionparams...} + hparams = Any[] + for i = 1:n + uparams = Any[] + for t in ts + tpᵢ = (t::DataType).parameters[i] + tpᵢ isa Type || return u + push!(uparams, tpᵢ) + end + push!(hparams, Union{uparams...}) + end + return Head{hparams...} end function unwraptv_ub(@nospecialize t) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 1cba2d2ee0006..c163d141fc08f 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -1390,3 +1390,34 @@ function wrap1_wrap1_wrapper(b, x, y) end @test wrap1_wrap1_wrapper(true, 1, 1.0) === 1.0 @test wrap1_wrap1_wrapper(false, 1, 1.0) === 1 + +# Test unswitching-union optimization within SRO Apass +function sroaunswitchuniontuple(c, x1, x2) + t = c ? (x1,) : (x2,) + return getfield(t, 1) +end +struct SROAUnswitchUnion1{T} + x::T +end +struct SROAUnswitchUnion2{S,T} + x::T + @inline SROAUnswitchUnion2{S}(x::T) where {S,T} = new{S,T}(x) +end +function sroaunswitchunionstruct1(c, x1, x2) + x = c ? SROAUnswitchUnion1(x1) : SROAUnswitchUnion1(x2) + return getfield(x, :x) +end +function sroaunswitchunionstruct2(c, x1, x2) + x = c ? SROAUnswitchUnion2{:a}(x1) : SROAUnswitchUnion2{:a}(x2) + return getfield(x, :x) +end +let src = code_typed1(sroaunswitchuniontuple, Tuple{Bool, Int, Float64}) + @test count(isnew, src.code) == 0 + @test count(iscall((src, getfield)), src.code) == 0 +end +let src = code_typed1(sroaunswitchunionstruct1, Tuple{Bool, Int, Float64}) + @test count(isnew, src.code) == 0 + @test count(iscall((src, getfield)), src.code) == 0 +end +@test sroaunswitchunionstruct2(true, 1, 1.0) === 1 +@test sroaunswitchunionstruct2(false, 1, 1.0) === 1.0