From d7d2c5395c2e4f540110ef0145139f116a7dff48 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Fri, 12 Mar 2021 01:42:22 +0900 Subject: [PATCH] inference: allows conditional object to propagate constraint multiple times (#39936) Currently we always `widenconditional` conditional var state, which makes us unable to propagate constraints from conditional object multiple times: ```julia @test Base.return_types((Union{Nothing,Int},)) do a b = a === nothing c = b ? 0 : a # c::Int d = !b ? a : 0 # d::Int ideally, but Union{Int,Nothing} c, d end == Any[Tuple{Int,Int}] # fail ``` This PR keeps conditional var state when the update is came from a conditional branching, and allows a conditional object to propagate constraint multiple times as far as the subject of condition doesn't change. AFAIU this is safe because the update from conditional branching doesn't change the condition itself. --- base/compiler/abstractinterpretation.jl | 6 ++-- base/compiler/typelattice.jl | 46 +++++++++++++++---------- test/compiler/inference.jl | 25 ++++++++++++++ 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5835a7d6f5602..4fd1a26b315b3 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1773,12 +1773,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.src.ssavaluetypes[pc] = t lhs = stmt.args[1] if isa(lhs, Slot) - changes = StateUpdate(lhs, VarState(t, false), changes) + changes = StateUpdate(lhs, VarState(t, false), changes, false) end elseif hd === :method fname = stmt.args[1] if isa(fname, Slot) - changes = StateUpdate(fname, VarState(Any, false), changes) + changes = StateUpdate(fname, VarState(Any, false), changes, false) end elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect # these do not generate code @@ -1843,7 +1843,7 @@ end function conditional_changes(changes::VarTable, @nospecialize(typ), var::Slot) if typ ⊑ (changes[slot_id(var)]::VarState).typ - return StateUpdate(var, VarState(typ, false), changes) + return StateUpdate(var, VarState(typ, false), changes, true) end return changes end diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 17ff193db24c0..e2668e15c9e2c 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -87,6 +87,7 @@ struct StateUpdate var::Union{Slot,SSAValue} vtype::VarState state::VarTable + conditional::Bool end # Represent that the type estimate has been approximated, due to "causes" @@ -321,16 +322,19 @@ function stupdate!(state::Nothing, changes::StateUpdate) changeid = slot_id(changes.var::Slot) newst[changeid] = changes.vtype # remove any Conditional for this Slot from the vtable - for i = 1:length(newst) - newtype = newst[i] - if isa(newtype, VarState) - newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid - newtypetyp = widenconditional(newtypetyp) - if newtype.typ isa LimitedAccuracy - newtypetyp = LimitedAccuracy(newtypetyp, newtype.typ.causes) + # (unless this change is came from the conditional) + if !changes.conditional + for i = 1:length(newst) + newtype = newst[i] + if isa(newtype, VarState) + newtypetyp = ignorelimited(newtype.typ) + if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid + newtypetyp = widenconditional(newtypetyp) + if newtype.typ isa LimitedAccuracy + newtypetyp = LimitedAccuracy(newtypetyp, newtype.typ.causes) + end + newst[i] = VarState(newtypetyp, newtype.undef) end - newst[i] = VarState(newtypetyp, newtype.undef) end end end @@ -352,7 +356,8 @@ function stupdate!(state::VarTable, changes::StateUpdate) end oldtype = state[i] # remove any Conditional for this Slot from the vtable - if isa(newtype, VarState) + # (unless this change is came from the conditional) + if !changes.conditional && isa(newtype, VarState) newtypetyp = ignorelimited(newtype.typ) if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid newtypetyp = widenconditional(newtypetyp) @@ -393,16 +398,19 @@ function stupdate1!(state::VarTable, change::StateUpdate) end changeid = slot_id(change.var::Slot) # remove any Conditional for this Slot from the catch block vtable - for i = 1:length(state) - oldtype = state[i] - if isa(oldtype, VarState) - oldtypetyp = ignorelimited(oldtype.typ) - if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid - oldtypetyp = widenconditional(oldtypetyp) - if oldtype.typ isa LimitedAccuracy - oldtypetyp = LimitedAccuracy(oldtypetyp, oldtype.typ.causes) + # (unless this change is came from the conditional) + if !change.conditional + for i = 1:length(state) + oldtype = state[i] + if isa(oldtype, VarState) + oldtypetyp = ignorelimited(oldtype.typ) + if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid + oldtypetyp = widenconditional(oldtypetyp) + if oldtype.typ isa LimitedAccuracy + oldtypetyp = LimitedAccuracy(oldtypetyp, oldtype.typ.causes) + end + state[i] = VarState(oldtypetyp, oldtype.undef) end - state[i] = VarState(oldtypetyp, oldtype.undef) end end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 73dbedf8a39d5..3ac749afaf936 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1795,6 +1795,31 @@ end end == Any[Union{Nothing,Expr}] end +@testset "branching on conditional object" begin + # simple + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + return b ? 0 : a # ::Int + end == Any[Int] + + # can use multiple times (as far as the subject of condition hasn't changed) + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + c = b ? 0 : a # c::Int + d = !b ? a : 0 # d::Int + return c, d # ::Tuple{Int,Int} + end == Any[Tuple{Int,Int}] + + # shouldn't use the old constraint when the subject of condition has changed + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + c = b ? 0 : a # c::Int + a = 0 + d = b ? a : 1 # d::Int, not d::Union{Nothing,Int} + return c, d # ::Tuple{Int,Int} + end == Any[Tuple{Int,Int}] +end + function f25579(g) h = g[] t = (h === nothing)