diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 8f20ac28e3606..bf3635f81ccf9 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -64,7 +64,7 @@ function try_compute_field(ir::Union{IncrementalCompact,IRCode}, @nospecialize(f end # assume `stmt` is a call of `getfield`/`setfield!`/`isdefined` -function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType) +function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, @nospecialize(typ)) field = try_compute_field(ir, stmt.args[3]) return try_compute_fieldidx(typ, field) end @@ -1106,24 +1106,24 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) val = stmt.args[2] end struct_typ = widenconst(argextype(val, compact)) - struct_typ_unwrapped = unwrap_unionall(struct_typ) - if isa(struct_typ, Union) - struct_typ_unwrapped = unswitchtypeunion(struct_typ_unwrapped) - end - if isa(struct_typ_unwrapped, Union) && is_isdefined - lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ) + struct_typ_name = argument_datatype(struct_typ) + if struct_typ_name === nothing + if isa(struct_typ, Union) + lift_comparison!(isdefined, compact, idx, stmt, lifting_cache, 𝕃ₒ) + end continue + else + struct_typ_name = struct_typ_name.name end - isa(struct_typ_unwrapped, DataType) || continue - struct_typ_unwrapped.name.atomicfields == C_NULL || continue # TODO: handle more + struct_typ_name.atomicfields == C_NULL || continue # TODO: handle more if !((field_ordering === :unspecified) || (field_ordering isa Const && field_ordering.val === :not_atomic)) continue end # analyze this mutable struct here for the later pass - if ismutabletype(struct_typ_unwrapped) + if ismutabletypename(struct_typ_name) isa(val, SSAValue) || continue let intermediaries = SPCSet() callback = IntermediaryCollector(intermediaries) @@ -1153,7 +1153,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) end # perform SROA on immutable structs here on - field = try_compute_fieldidx_stmt(compact, stmt, struct_typ_unwrapped) + field = try_compute_fieldidx_stmt(compact, stmt, struct_typ) field === nothing && continue leaves, visited_philikes = collect_leaves(compact, val, struct_typ, 𝕃ₒ, phi_or_ifelse_predecessors) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index 117f5288418e1..e431de009affc 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -877,13 +877,10 @@ function fieldcount_noerror(@nospecialize t) if t === nothing return nothing end - t = t::DataType elseif t === Union{} return 0 end - if !(t isa DataType) - return nothing - end + t isa DataType || return nothing if t.name === _NAMEDTUPLE_NAME names, types = t.parameters if names isa Tuple @@ -892,17 +889,16 @@ function fieldcount_noerror(@nospecialize t) if types isa DataType && types <: Tuple return fieldcount_noerror(types) end - abstr = true - else - abstr = isabstracttype(t) || (t.name === Tuple.name && isvatuple(t)) - end - if abstr + return nothing + elseif isabstracttype(t) || (t.name === Tuple.name && isvatuple(t)) return nothing end return isdefined(t, :types) ? length(t.types) : length(t.name.names) end -function try_compute_fieldidx(typ::DataType, @nospecialize(field)) +function try_compute_fieldidx(@nospecialize(typ), @nospecialize(field)) + typ = argument_datatype(typ) + typ === nothing && return nothing if isa(field, Symbol) field = fieldindex(typ, field, false) field == 0 && return nothing diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index 7383ec2a440bf..038e5f8abdf88 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -317,42 +317,6 @@ function unionall_depth(@nospecialize ua) # aka subtype_env_size return depth end -# convert a Union of same `UnionAll` types to the `UnionAll` type whose parameter is the Unions -function unswitchtypeunion(u::Union, typename::Union{Nothing,Core.TypeName}=nothing) - ts = uniontypes(u) - n = -1 - for t in ts - t isa DataType || return u - if typename === nothing - typename = t.name - elseif typename !== t.name - return u - end - params = t.parameters - np = length(params) - if np == 0 || isvarargtype(params[end]) - return u - end - if n == -1 - n = np - elseif n ≠ np - return u - end - end - Head = (typename::Core.TypeName).wrapper - hparams = Any[] - for i = 1:n - uparams = Any[] - for t in ts - tpᵢ = (t::DataType).parameters[i] - tpᵢ isa Type || return u - push!(uparams, tpᵢ) - end - push!(hparams, Union{uparams...}) - end - return Head{hparams...} -end - function unwraptv_ub(@nospecialize t) while isa(t, TypeVar) t = t.ub diff --git a/base/reflection.jl b/base/reflection.jl index 05ffb3a6e9211..0cd48e0c822ce 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -529,9 +529,11 @@ function ismutabletype(@nospecialize t) @_total_meta t = unwrap_unionall(t) # TODO: what to do for `Union`? - return isa(t, DataType) && t.name.flags & 0x2 == 0x2 + return isa(t, DataType) && ismutabletypename(t.name) end +ismutabletypename(tn::Core.TypeName) = tn.flags & 0x2 == 0x2 + """ isstructtype(T) -> Bool