Skip to content

Commit

Permalink
SROA: don't use unswitchtupleunion and explicitly use type name only (
Browse files Browse the repository at this point in the history
#50522)

Since construction of `UnionAll` of `Union`s can be expensive. The SROA pass just needs to
look at type name information and do not need to propagate full type objects.

- xref: <#50511 (comment)>
- closes #50511
  • Loading branch information
aviatesk authored Jul 13, 2023
1 parent 824cdf1 commit 9b73611
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 58 deletions.
22 changes: 11 additions & 11 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(f
end

# assume `stmt` is a call of `getfield`/`setfield!`/`isdefined`
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, @nospecialize(typ))
field = try_compute_field(ir, stmt.args[3])
return try_compute_fieldidx(typ, field)
end
Expand Down Expand Up @@ -1106,24 +1106,24 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
val = stmt.args[2]
end
struct_typ = widenconst(argextype(val, compact))
struct_typ_unwrapped = unwrap_unionall(struct_typ)
if isa(struct_typ, Union)
struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped)
end
if isa(struct_typ_unwrapped, Union) && is_isdefined
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
struct_typ_name = argument_datatype(struct_typ)
if struct_typ_name === nothing
if isa(struct_typ, Union)

This comment has been minimized.

Copy link
@Keno

Keno Jul 14, 2023

Member

Lost && is_isdefined condition?

This comment has been minimized.

Copy link
@Keno

Keno Jul 14, 2023

Member
lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ)
end
continue
else
struct_typ_name = struct_typ_name.name
end
isa(struct_typ_unwrapped, DataType) || continue

struct_typ_unwrapped.name.atomicfields == C_NULL || continue # TODO: handle more
struct_typ_name.atomicfields == C_NULL || continue # TODO: handle more
if !((field_ordering === :unspecified) ||
(field_ordering isa Const && field_ordering.val === :not_atomic))
continue
end

# analyze this mutable struct here for the later pass
if ismutabletype(struct_typ_unwrapped)
if ismutabletypename(struct_typ_name)
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = IntermediaryCollector(intermediaries)
Expand Down Expand Up @@ -1153,7 +1153,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
end

# perform SROA on immutable structs here on
field = try_compute_fieldidx_stmt(compact, stmt, struct_typ_unwrapped)
field = try_compute_fieldidx_stmt(compact, stmt, struct_typ)
field === nothing && continue

leaves, visited_philikes = collect_leaves(compact, val, struct_typ, 𝕃ₒ, phi_or_ifelse_predecessors)
Expand Down
16 changes: 6 additions & 10 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -877,13 +877,10 @@ function fieldcount_noerror(@nospecialize t)
if t === nothing
return nothing
end
t = t::DataType
elseif t === Union{}
return 0
end
if !(t isa DataType)
return nothing
end
t isa DataType || return nothing
if t.name === _NAMEDTUPLE_NAME
names, types = t.parameters
if names isa Tuple
Expand All @@ -892,17 +889,16 @@ function fieldcount_noerror(@nospecialize t)
if types isa DataType && types <: Tuple
return fieldcount_noerror(types)
end
abstr = true
else
abstr = isabstracttype(t) || (t.name === Tuple.name && isvatuple(t))
end
if abstr
return nothing
elseif isabstracttype(t) || (t.name === Tuple.name && isvatuple(t))
return nothing
end
return isdefined(t, :types) ? length(t.types) : length(t.name.names)
end

function try_compute_fieldidx(typ::DataType, @nospecialize(field))
function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field))
typ = argument_datatype(typ)
typ === nothing && return nothing
if isa(field, Symbol)
field = fieldindex(typ, field, false)
field == 0 && return nothing
Expand Down
36 changes: 0 additions & 36 deletions base/compiler/typeutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,42 +317,6 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size
return depth
end

# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions
function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing)
ts = uniontypes(u)
n = -1
for t in ts
t isa DataType || return u
if typename === nothing
typename = t.name
elseif typename !== t.name
return u
end
params = t.parameters
np = length(params)
if np == 0 || isvarargtype(params[end])
return u
end
if n == -1
n = np
elseif n np
return u
end
end
Head = (typename::Core.TypeName).wrapper
hparams = Any[]
for i = 1:n
uparams = Any[]
for t in ts
tpᵢ = (t::DataType).parameters[i]
tpᵢ isa Type || return u
push!(uparams, tpᵢ)
end
push!(hparams, Union{uparams...})
end
return Head{hparams...}
end

function unwraptv_ub(@nospecialize t)
while isa(t, TypeVar)
t = t.ub
Expand Down
4 changes: 3 additions & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,11 @@ function ismutabletype(@nospecialize t)
@_total_meta
t = unwrap_unionall(t)
# TODO: what to do for `Union`?
return isa(t, DataType) && t.name.flags & 0x2 == 0x2
return isa(t, DataType) && ismutabletypename(t.name)
end

ismutabletypename(tn::Core.TypeName) = tn.flags & 0x2 == 0x2

"""
isstructtype(T) -> Bool
Expand Down

0 comments on commit 9b73611

Please sign in to comment.