Skip to content

Commit

Permalink
SROA: generalize unswitchtupleunion optimization
Browse files Browse the repository at this point in the history
This commit improves SROA pass by extending the `unswitchtupleunion`
optimization to handle the general parametric types, e.g.:
```julia
julia> struct A{T}
           x::T
       end;

julia> function foo(a1, a2, c)
           t = c ? A(a1) : A(a2)
           return getfield(t, :x)
       end;

julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA"))
```

> Before
```
2 1 ─      goto #3 if not _4                                          │
  2 ─ %2 = %new(A{Int64}, _2)::A{Int64}                               │╻ A
  └──      goto #4                                                    │
  3 ─ %4 = %new(A{Float64}, _3)::A{Float64}                           │╻ A
  4 ┄ %5 = φ (#2 => %2, #3 => %4)::Union{A{Float64}, A{Int64}}        │
3 │   %6 = Main.getfield(%5, :x)::Union{Float64, Int64}               │
  └──      return %6                                                  │
   => Union{Float64, Int64}
```

> After
```
julia> only(Base.code_ircode(foo, (Int,Float64,Bool); optimize_until="SROA"))
2 1 ─      goto #3 if not _4                                           │
  2 ─      nothing::A{Int64}                                           │╻ A
  └──      goto #4                                                     │
  3 ─      nothing::A{Float64}                                         │╻ A
  4 ┄ %8 = φ (#2 => _2, #3 => _3)::Union{Float64, Int64}               │
  │        nothing::Union{A{Float64}, A{Int64}}
3 │   %6 = %8::Union{Float64, Int64}                                   │
  └──      return %6                                                   │
   => Union{Float64, Int64}
```
  • Loading branch information
aviatesk committed Jul 11, 2023
1 parent 680e3b3 commit d7bf34a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 21 deletions.
4 changes: 2 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1107,8 +1107,8 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
end
struct_typ = widenconst(argextype(val, compact))
struct_typ_unwrapped = unwrap_unionall(struct_typ)
if isa(struct_typ, Union) && struct_typ <: Tuple
struct_typ_unwrapped = unswitchtupleunion(struct_typ_unwrapped)
if isa(struct_typ, Union)
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
end
if isa(struct_typ_unwrapped, Union) && is_isdefined
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
Expand Down
45 changes: 26 additions & 19 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,33 +317,40 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
return depth
end

# convert a Union of Tuple types to a Tuple of Unions
unswitchtupleunion(u::Union) = unswitchtypeunion(u, Tuple.name)

# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
ts = uniontypes(u)
n = -1
for t in ts
if t isa DataType
if typename === nothing
typename = t.name
elseif typename !== t.name
return u
end
if length(t.parameters) != 0 && !isvarargtype(t.parameters[end])
if n == -1
n = length(t.parameters)
elseif n != length(t.parameters)
return u
end
end
else
t isa DataType || return u
if typename === nothing
typename = t.name
elseif typename !== t.name
return u
end
params = t.parameters
np = length(params)
if np == 0 || isvarargtype(params[end])
return u
end
if n == -1
n = np
elseif n np
return u
end
end
Head = (typename::Core.TypeName).wrapper
unionparams = Any[ Union{Any[(t::DataType).parameters[i] for t in ts]...} for i in 1:n ]
return Head{unionparams...}
hparams = Any[]
for i = 1:n
uparams = Any[]
for t in ts
tpᵢ = (t::DataType).parameters[i]
tpᵢ isa Type || return u
push!(uparams, tpᵢ)
end
push!(hparams, Union{uparams...})
end
return Head{hparams...}
end

function unwraptv_ub(@nospecialize t)
Expand Down
31 changes: 31 additions & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1390,3 +1390,34 @@ function wrap1_wrap1_wrapper(b, x, y)
end
@test wrap1_wrap1_wrapper(true, 1, 1.0) === 1.0
@test wrap1_wrap1_wrapper(false, 1, 1.0) === 1

# Test unswitching-union optimization within SRO Apass
function sroaunswitchuniontuple(c, x1, x2)
t = c ? (x1,) : (x2,)
return getfield(t, 1)
end
struct SROAUnswitchUnion1{T}
x::T
end
struct SROAUnswitchUnion2{S,T}
x::T
@inline SROAUnswitchUnion2{S}(x::T) where {S,T} = new{S,T}(x)
end
function sroaunswitchunionstruct1(c, x1, x2)
x = c ? SROAUnswitchUnion1(x1) : SROAUnswitchUnion1(x2)
return getfield(x, :x)
end
function sroaunswitchunionstruct2(c, x1, x2)
x = c ? SROAUnswitchUnion2{:a}(x1) : SROAUnswitchUnion2{:a}(x2)
return getfield(x, :x)
end
let src = code_typed1(sroaunswitchuniontuple, Tuple{Bool, Int, Float64})
@test count(isnew, src.code) == 0
@test count(iscall((src, getfield)), src.code) == 0
end
let src = code_typed1(sroaunswitchunionstruct1, Tuple{Bool, Int, Float64})
@test count(isnew, src.code) == 0
@test count(iscall((src, getfield)), src.code) == 0
end
@test sroaunswitchunionstruct2(true, 1, 1.0) === 1
@test sroaunswitchunionstruct2(false, 1, 1.0) === 1.0

0 comments on commit d7bf34a

Please sign in to comment.