Skip to content

Commit

Permalink
inference: model partially initialized structs with PartialStruct
Browse files Browse the repository at this point in the history
There is still room for improvement in the accuracy of `getfield` and
`isdefined` for structs with uninitialized fields. This commit aims to
enhance the accuracy of struct field defined-ness by propagating such
structs as `PartialStruct` in cases where fields that might be
uninitialized are confirmed to be defined. Specifically, the
improvements are made in the following situations:
1. when a `:new` expression receives more arguments than the minimum
   number of initialized fields.
2. when new information about the initialized fields of `x` can be
   obtained in the `then` branch of `if isdefined(x, :f)`.

Combined with the existing optimizations, these improvements enable DCE
in scenarios such as:
```julia
julia> @noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing);

julia> @allocated broadcast_noescape1(Ref("x"))
16 # master
0  # this PR
```

One important point to note is that, as revealed in
#48999, fields and globals can revert to `undef` during
precompilation. This commit does not affect globals. Furthermore, even
for fields, the refinements made by 1. and 2. are propagated along with
data-flow, and field defined-ness information is only used when fields
are confirmed to be initialized. Therefore, this commit does not introduce
the same issues as #48999.
  • Loading branch information
aviatesk committed Jul 30, 2024
1 parent 686804d commit 8ce98cf
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 47 deletions.
73 changes: 45 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1968,33 +1968,52 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
return Conditional(aty.slot, thentype, elsetype)
end
elseif f === isdefined
uty = argtypes[2]
a = ssa_def_slot(fargs[2], sv)
if isa(uty, Union) && isa(a, SlotNumber)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(uty)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
if isa(a, SlotNumber)
argtype2 = argtypes[2]
if isa(argtype2, Union)
fld = argtypes[3]
thentype = Bottom
elsetype = Bottom
for ty in uniontypes(argtype2)
cnd = isdefined_tfunc(𝕃ᵢ, ty, fld)
if isa(cnd, Const)
if cnd.val::Bool
thentype = thentype ty
else
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
else
thentype = thentype ty
elsetype = elsetype ty
end
return Conditional(a, thentype, elsetype)
else
thentype = form_partially_defined_struct(argtype2, argtypes[3])
if thentype !== nothing
return Conditional(a, thentype, argtype2)
end
end
return Conditional(a, thentype, elsetype)
end
end
end
@assert !isa(rt, TypeVar) "unhandled TypeVar"
return rt
end

# Try to refine `obj` into a `PartialStruct` that records the field named by `name`
# (and every field before it) as initialized. This is used for the `then` branch of
# `isdefined(x, name)`, where the named field is known to be defined.
#
# Returns `nothing` when no refinement is formed:
# - `obj` is already a `Const`,
# - `name` is not a `Const`,
# - the widened, `UnionAll`-unwrapped type of `obj` is abstract,
# - the field index cannot be computed from `name`, or
# - the field index does not exceed the minimum number of initialized fields of the
#   type, i.e. the type information alone already covers that field.
# Otherwise returns `PartialStruct(objt0, fields)` where `fields` holds the declared
# field types of fields `1:fldidx`.
function form_partially_defined_struct(@nospecialize(obj), @nospecialize(name))
    obj isa Const && return nothing # nothing to refine
    name isa Const || return nothing
    objt0 = widenconst(obj)
    objt = unwrap_unionall(objt0)
    isabstracttype(objt) && return nothing
    fldidx = try_compute_fieldidx(objt, name.val)
    fldidx === nothing && return nothing
    # no new information if the type alone already accounts for this field
    fldidx ≤ datatype_min_ninitialized(objt) && return nothing
    return PartialStruct(objt0, Any[fieldtype(objt0, i) for i = 1:fldidx])
end

function abstract_call_unionall(interp::AbstractInterpreter, argtypes::Vector{Any}, call::CallMeta)
na = length(argtypes)
if isvarargtype(argtypes[end])
Expand Down Expand Up @@ -2542,20 +2561,18 @@ function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{V
end
ats[i] = at
end
# For now, don't allow:
# - Const/PartialStruct of mutables (but still allow PartialStruct of mutables
# with `const` fields if anything refined)
# - partially initialized Const/PartialStruct
if fcount == nargs
if consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine
rt = PartialStruct(rt, ats)
if fcount == nargs && consistent === ALWAYS_TRUE && allconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
rt = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), rt, argvals, nargs))
elseif anyrefine || nargs > datatype_min_ninitialized(rt)
# propagate partially initialized struct as `PartialStruct` when:
# - any refinement information is available (`anyrefine`), or when
# - `nargs` is greater than `n_initialized` derived from the struct type
# information alone
rt = PartialStruct(rt, ats)
end
else
rt = refine_partial_type(rt)
Expand Down Expand Up @@ -3062,7 +3079,7 @@ end
@nospecializeinfer function widenreturn_partials(𝕃ᵢ::PartialsLattice, @nospecialize(rt), info::BestguessInfo)
if isa(rt, PartialStruct)
fields = copy(rt.fields)
local anyrefine = false
anyrefine = !isvarargtype(rt.fields[end]) && length(rt.fields) > datatype_min_ninitialized(rt.typ)
𝕃 = typeinf_lattice(info.interp)
= strictpartialorder(𝕃)
for i in 1:length(fields)
Expand Down
7 changes: 6 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1166,7 +1166,12 @@ struct IntermediaryCollector <: WalkerCallback
intermediaries::SPCSet
end
# Walker callback that collects intermediaries: for every definition `def` that is not
# an `Expr`, the id of its defining SSA value `defssa` is pushed into
# `walker_callback.intermediaries`.
#
# Returns `LiftedValue(def.val)` when `def` is a `PiNode`, and `nothing` otherwise.
# NOTE(review): presumably the `LiftedValue` tells the walker to continue through the
# value wrapped by the `PiNode` — confirm against the walker's callback contract.
function (walker_callback::IntermediaryCollector)(@nospecialize(def), @nospecialize(defssa::AnySSAValue))
    if !(def isa Expr)
        push!(walker_callback.intermediaries, defssa.id)
        if def isa PiNode
            return LiftedValue(def.val)
        end
    end
    return nothing
end

Expand Down
39 changes: 30 additions & 9 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,15 @@ end
if !ismutabletype(a1) || isconst(a1, idx)
return Const(isdefined(arg1.val, idx))
end
elseif isa(arg1, PartialStruct)
nflds = length(arg1.fields)
if !isvarargtype(arg1.fields[end])
if 1 idx nflds
return Const(true)
elseif !ismutabletype(a1) || isconst(a1, idx)
return Const(false)
end
end
elseif !isvatuple(a1)
fieldT = fieldtype(a1, idx)
if isa(fieldT, DataType) && isbitstype(fieldT)
Expand Down Expand Up @@ -980,27 +989,39 @@ end
= partialorder(𝕃)

# If we have s00 being a const, we can potentially refine our type-based analysis above
if isa(s00, Const) || isconstType(s00)
if !isa(s00, Const)
sv = (s00::DataType).parameters[1]
else
if isa(s00, Const) || isconstType(s00) || isa(s00, PartialStruct)
if isa(s00, Const)
sv = s00.val
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
elseif isa(s00, PartialStruct)
sty = s00.typ
nflds = fieldcount_noerror(sty)
ismod = false
else
sv = (s00::DataType).parameters[1]
sty = typeof(sv)
nflds = nfields(sv)
ismod = sv isa Module
end
if isa(name, Const)
nval = name.val
if !isa(nval, Symbol)
isa(sv, Module) && return false
ismod && return false
isa(nval, Int) || return false
end
return isdefined_tfunc(𝕃, s00, name) === Const(true)
end
boundscheck && return false

# If bounds checking is disabled and all fields are assigned,
# we may assume that we don't throw
isa(sv, Module) && return false
@assert !boundscheck
ismod && return false
name Int || name Symbol || return false
typeof(sv).name.n_uninitialized == 0 && return true
for i = (datatype_min_ninitialized(typeof(sv)) + 1):nfields(sv)
sty.name.n_uninitialized == 0 && return true
nflds === nothing && return false
for i = (datatype_min_ninitialized(sty)+1):nflds
isdefined_tfunc(𝕃, s00, Const(i)) === Const(true) || return false
end
return true
Expand Down
10 changes: 1 addition & 9 deletions test/compiler/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2139,21 +2139,13 @@ end
# ========================

# propagate escapes imposed on call arguments
@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing)
let result = code_escapes() do
broadcast_noescape1(Ref("Hi"))
end
i = only(findall(isnew, result.ir.stmts.stmt))
@test !has_return_escape(result.state[SSAValue(i)])
@test_broken !has_thrown_escape(result.state[SSAValue(i)]) # TODO `getfield(RefValue{String}, :x)` isn't safe
end
# `broadcast_noescape2` returns the result of `broadcast`, so the return-escape check
# stays `@test_broken` (see its TODO); the thrown-escape check is asserted with `@test`.
@noinline broadcast_noescape2(b) = broadcast(identity, b)
let result = code_escapes() do
        broadcast_noescape2(Ref("Hi"))
    end
    # index of the single `:new` statement in the analyzed IR
    i = only(findall(isnew, result.ir.stmts.stmt))
    @test_broken !has_return_escape(result.state[SSAValue(i)]) # TODO interprocedural alias analysis
    @test !has_thrown_escape(result.state[SSAValue(i)])
end
@noinline allescape_argument(a) = (global GV = a) # obvious escape
let result = code_escapes() do
Expand Down
89 changes: 89 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5866,3 +5866,92 @@ end
bar54341(args...) = foo54341(4, args...)

@test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int

# `PartialStruct` for partially initialized structs:
# Immutable fixture with three untyped fields. Its inner constructors initialize only
# the first one, two, or all three fields; `@nospecialize` (without arguments) applies
# to every constructor argument.
struct PartiallyInitialized1
    a
    b
    c
    function PartiallyInitialized1(a)
        @nospecialize
        return new(a)
    end
    function PartiallyInitialized1(a, b)
        @nospecialize
        return new(a, b)
    end
    function PartiallyInitialized1(a, b, c)
        @nospecialize
        return new(a, b, c)
    end
end
# Mutable counterpart of the fixture above: three untyped fields, with inner
# constructors that initialize only the first one, two, or all three of them.
# `@nospecialize` (without arguments) applies to every constructor argument.
mutable struct PartiallyInitialized2
    a
    b
    c
    function PartiallyInitialized2(a)
        @nospecialize
        return new(a)
    end
    function PartiallyInitialized2(a, b)
        @nospecialize
        return new(a, b)
    end
    function PartiallyInitialized2(a, b, c)
        @nospecialize
        return new(a, b, c)
    end
end

# 1. isdefined modeling for partial struct
# Immutable struct: the number of arguments given to the constructor decides the
# expected result exactly (`:b` defined and `:c` undefined after two arguments).
@test Base.infer_return_type((Any,Any)) do a, b
Val(isdefined(PartiallyInitialized1(a, b), :b))
end == Val{true}
@test Base.infer_return_type((Any,Any,)) do a, b
Val(isdefined(PartiallyInitialized1(a, b), :c))
end == Val{false}
@test Base.infer_return_type((Any,Any,Any)) do a, b, c
Val(isdefined(PartiallyInitialized1(a, b, c), :c))
end == Val{true}
# Mutable struct: initialized fields are still expected to infer as defined, but for
# a field not passed to the constructor the tests only require the inferred type to
# be a supertype (`>:`) of the precise answer, since the field may be assigned later
# (as the `s.c = c` case below does).
@test Base.infer_return_type((Any,Any)) do a, b
Val(isdefined(PartiallyInitialized2(a, b), :b))
end == Val{true}
@test Base.infer_return_type((Any,Any,)) do a, b
Val(isdefined(PartiallyInitialized2(a, b), :c))
end >: Val{false}
@test Base.infer_return_type((Any,Any,Any)) do a, b, c
s = PartiallyInitialized2(a, b)
s.c = c
Val(isdefined(s, :c))
end >: Val{true}
@test Base.infer_return_type((Any,Any,Any)) do a, b, c
Val(isdefined(PartiallyInitialized2(a, b, c), :c))
end == Val{true}
# Tuple built by splatting a vector of unknown length: index 1 is expected to infer
# as defined, while index 2 is expected to stay undetermined (plain `Val`).
@test Base.infer_return_type((Vector{Int},)) do xs
Val(isdefined(tuple(1, xs...), 1))
end == Val{true}
@test Base.infer_return_type((Vector{Int},)) do xs
Val(isdefined(tuple(1, xs...), 2))
end == Val

# 2. getfield modeling for partial struct
# Each struct (immutable first, then mutable) is checked with the same four cases:
# - constant field name of a field passed to the constructor: expected nothrow
# - non-constant `Symbol` field name with bounds checking disabled:
#   expected nothrow only when all three fields were passed to the constructor
@test Base.infer_effects((Any,Any); optimize=false) do a, b
getfield(PartiallyInitialized1(a, b), :b)
end |> Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f
getfield(PartiallyInitialized1(a, b), f, #=boundscheck=#false)
end |> !Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c
getfield(PartiallyInitialized1(a, b, c), :c)
end |> Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f
getfield(PartiallyInitialized1(a, b, c), f, #=boundscheck=#false)
end |> Core.Compiler.is_nothrow
# same four cases for the mutable struct
@test Base.infer_effects((Any,Any); optimize=false) do a, b
getfield(PartiallyInitialized2(a, b), :b)
end |> Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Symbol,); optimize=false) do a, b, f
getfield(PartiallyInitialized2(a, b), f, #=boundscheck=#false)
end |> !Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Any); optimize=false) do a, b, c
getfield(PartiallyInitialized2(a, b, c), :c)
end |> Core.Compiler.is_nothrow
@test Base.infer_effects((Any,Any,Any,Symbol); optimize=false) do a, b, c, f
getfield(PartiallyInitialized2(a, b, c), f, #=boundscheck=#false)
end |> Core.Compiler.is_nothrow

# isdefined-Conditionals
# A field access guarded by `isdefined(x, :x)` is expected to be inferred as nothrow,
# because the `then` branch knows the field is defined.
@test Base.infer_effects((Base.RefValue{Any},)) do x
if isdefined(x, :x)
return getfield(x, :x)
end
end |> Core.Compiler.is_nothrow
# Same expectation through the higher-level `isassigned` / `x[]` API.
@test Base.infer_effects((Base.RefValue{Any},)) do x
if isassigned(x)
return x[]
end
end |> Core.Compiler.is_nothrow

# End to end test case for the partially initialized struct with `PartialStruct`
# `broadcast_noescape1` discards the broadcast result and returns `nothing`; the test
# expects the whole call, including the `Ref("x")` allocation, to be eliminated.
@noinline broadcast_noescape1(a) = (broadcast(identity, a); nothing)
@test fully_eliminated() do
broadcast_noescape1(Ref("x"))
end

0 comments on commit 8ce98cf

Please sign in to comment.