Skip to content

Commit

Permalink
inference: form PartialStruct for extra type information propagation (
Browse files Browse the repository at this point in the history
#42831)

* inference: form `PartialStruct` for extra type information propagation

This commit forms `PartialStruct` whenever there is any type-level
refinement available about a field, even if it's not "constant" information.

In Julia "definitions" are allowed to be abstract whereas "usages"
(i.e. callsites) are often concrete. The basic idea is to allow inference
to make more use of such precise callsite type information by encoding it
as `PartialStruct`.

This may increase optimization possibilities of "unidiomatic" Julia code,
which may contain poorly-typed definitions, like this very contrived example:
```julia
struct Problem
    n; s; c; t
end

function main(args...)
    prob = Problem(args...)
    s = 0
    for i in 1:prob.n
        m = mod(i, 3)
        s += m == 0 ? sin(prob.s) : m == 1 ? cos(prob.c) : tan(prob.t)
    end
    return prob, s
end

main(10000, 1, 2, 3)
```

One of the obvious limitation is that this extra type information can be
propagated inter-procedurally only as a const-propagation.
I'm not sure this kind of "just a type-level" refinement can often make
constant-prop' successful (i.e. shape-up a method body and allow it to
be inlined, encoding the extra type information into the generated code),
thus I didn't not modify any part of const-prop' heuristics.

So the improvements from this change might not be very useful for general
inter-procedural analysis currently, but they should definitely improve the
accuracy of local analysis and very simple inter-procedural analysis.
  • Loading branch information
aviatesk authored Nov 1, 2021
1 parent 6c274ed commit a121721
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 45 deletions.
57 changes: 29 additions & 28 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,9 @@ function from_interconditional(@nospecialize(typ), (; fargs, argtypes)::ArgInfo,
else
elsetype = tmeet(elsetype, widenconst(new_elsetype))
end
if (slot > 0 || condval !== false) && !(old vtype) # essentially vtype ⋤ old
if (slot > 0 || condval !== false) && vtype old
slot = id
elseif (slot > 0 || condval !== true) && !(old elsetype) # essentially elsetype ⋤ old
elseif (slot > 0 || condval !== true) && elsetype old
slot = id
else # reset: no new useful information for this slot
vtype = elsetype = Any
Expand Down Expand Up @@ -1598,36 +1598,35 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
elseif ehead === :new
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
if isconcretetype(t) && !ismutabletype(t)
args = Vector{Any}(undef, length(e.args)-1)
ats = Vector{Any}(undef, length(e.args)-1)
anyconst = false
allconst = true
nargs = length(e.args) - 1
ats = Vector{Any}(undef, nargs)
local anyrefine = false
local allconst = true
for i = 2:length(e.args)
at = widenconditional(abstract_eval_value(interp, e.args[i], vtypes, sv))
if !anyconst
anyconst = has_nontrivial_const_info(at)
end
ats[i-1] = at
ft = fieldtype(t, i-1)
at = tmeet(at, ft)
if at === Bottom
t = Bottom
allconst = anyconst = false
break
elseif at isa Const
if !(at.val isa fieldtype(t, i - 1))
t = Bottom
allconst = anyconst = false
break
end
args[i-1] = at.val
else
@goto t_computed
elseif !isa(at, Const)
allconst = false
end
if !anyrefine
anyrefine = has_nontrivial_const_info(at) || # constant information
at ft # just a type-level information, but more precise than the declared type
end
ats[i-1] = at
end
# For now, don't allow partially initialized Const/PartialStruct
if t !== Bottom && fieldcount(t) == length(ats)
if fieldcount(t) == nargs
if allconst
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
elseif anyconst
argvals = Vector{Any}(undef, nargs)
for j in 1:nargs
argvals[j] = (ats[j]::Const).val
end
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, argvals, nargs))
elseif anyrefine
t = PartialStruct(t, ats)
end
end
Expand All @@ -1638,7 +1637,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
n = fieldcount(t)
if isa(at, Const) && isa(at.val, Tuple) && n == length(at.val::Tuple) &&
let t = t; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
let t = t, at = at; _all(i->getfield(at.val::Tuple, i) isa fieldtype(t, i), 1:n); end
t = Const(ccall(:jl_new_structt, Any, (Any, Any), t, at.val))
elseif isa(at, PartialStruct) && at Tuple && n == length(at.fields::Vector{Any}) &&
let t = t, at = at; _all(i->(at.fields::Vector{Any})[i] fieldtype(t, i), 1:n); end
Expand Down Expand Up @@ -1718,6 +1717,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
else
t = abstract_eval_value_expr(interp, e, vtypes, sv)
end
@label t_computed
@assert !isa(t, TypeVar) "unhandled TypeVar"
if isa(t, DataType) && isdefined(t, :instance)
# replace singleton types with their equivalent Const object
Expand Down Expand Up @@ -1801,17 +1801,18 @@ function widenreturn(@nospecialize(rt), @nospecialize(bestguess), nslots::Int, s
isa(rt, Type) && return rt
if isa(rt, PartialStruct)
fields = copy(rt.fields)
haveconst = false
local anyrefine = false
for i in 1:length(fields)
a = fields[i]
a = isvarargtype(a) ? a : widenreturn(a, bestguess, nslots, slottypes, changes)
if !haveconst && has_const_info(a)
if !anyrefine
# TODO: consider adding && const_prop_profitable(a) here?
haveconst = true
anyrefine = has_const_info(a) ||
a fieldtype(rt.typ, i)
end
fields[i] = a
end
haveconst && return PartialStruct(rt.typ, fields)
anyrefine && return PartialStruct(rt.typ, fields)
end
if isa(rt, PartialOpaque)
return rt # XXX: this case was missed in #39512
Expand Down
23 changes: 22 additions & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ function maybe_extract_const_bool(c::AnyConditional)
end
maybe_extract_const_bool(@nospecialize c) = nothing

function (@nospecialize(a), @nospecialize(b))
"""
a ⊑ b -> Bool
The non-strict partial order over the type inference lattice.
"""
@nospecialize(a) @nospecialize(b) = begin
if isa(b, LimitedAccuracy)
if !isa(a, LimitedAccuracy)
return false
Expand Down Expand Up @@ -232,6 +237,22 @@ function ⊑(@nospecialize(a), @nospecialize(b))
end
end

"""
a ⊏ b -> Bool
The strict partial order over the type inference lattice.
This is defined as the irreflexive kernel of `⊑`.
"""
@nospecialize(a) @nospecialize(b) = a b && !(b, a)

"""
a ⋤ b -> Bool
This order could be used as a slightly more efficient version of the strict order `⊏`,
where we can safely assume `a ⊑ b` holds.
"""
@nospecialize(a) @nospecialize(b) = !(b, a)

# Check if two lattice elements are partial order equivalent. This is basically
# `a ⊑ b && b ⊑ a` but with extra performance optimizations.
function is_lattice_equal(@nospecialize(a), @nospecialize(b))
Expand Down
23 changes: 23 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3669,3 +3669,26 @@ end

# issue #42646
@test only(Base.return_types(getindex, (Array{undef}, Int))) >: Union{} # check that it does not throw

# form PartialStruct for extra type information propagation
struct FieldTypeRefinement{S,T}
s::S
t::T
end
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Int}(s, s)
o.s
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Int,Any}(s, s)
o.t
end |> only == Int
@test Base.return_types((Int,)) do s
o = FieldTypeRefinement{Any,Any}(s, s)
o.s, o.t
end |> only == Tuple{Int,Int}
@test Base.return_types((Int,)) do a
s1 = Some{Any}(a)
s2 = Some{Any}(s1)
s2.value.value
end |> only == Int
23 changes: 7 additions & 16 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,31 +426,22 @@ let # `getfield_elim_pass!` should work with constant globals
end
end

let # `typeassert_elim_pass!`
let
# `typeassert` elimination after SROA
# NOTE we can remove this optimization once inference is able to reason about memory-effects
src = @eval Module() begin
struct Foo; x; end
mutable struct Foo; x; end

code_typed((Int,)) do a
x1 = Foo(a)
x2 = Foo(x1)
x3 = Foo(x2)

r1 = (x2.x::Foo).x
r2 = (x2.x::Foo).x::Int
r3 = (x2.x::Foo).x::Integer
r4 = ((x3.x::Foo).x::Foo).x

return r1, r2, r3, r4
return typeassert(x2.x, Foo).x
end |> only |> first
end
# eliminate `typeassert(f2.a, Foo)`
@test all(src.code) do @nospecialize(stmt)
# eliminate `typeassert(x2.x, Foo)`
@test all(src.code) do @nospecialize stmt
Meta.isexpr(stmt, :call) || return true
ft = Core.Compiler.argextype(stmt.args[1], src, Any[], src.slottypes)
return Core.Compiler.widenconst(ft) !== typeof(typeassert)
end
# succeeding simple DCE will eliminate `Foo(a)`
@test all(src.code) do @nospecialize(stmt)
return !Meta.isexpr(stmt, :new)
end
end

0 comments on commit a121721

Please sign in to comment.