diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 5835a7d6f5602..4fd1a26b315b3 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1773,12 +1773,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.src.ssavaluetypes[pc] = t lhs = stmt.args[1] if isa(lhs, Slot) - changes = StateUpdate(lhs, VarState(t, false), changes) + changes = StateUpdate(lhs, VarState(t, false), changes, false) end elseif hd === :method fname = stmt.args[1] if isa(fname, Slot) - changes = StateUpdate(fname, VarState(Any, false), changes) + changes = StateUpdate(fname, VarState(Any, false), changes, false) end elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect # these do not generate code @@ -1843,7 +1843,7 @@ end function conditional_changes(changes::VarTable, @nospecialize(typ), var::Slot) if typ ⊑ (changes[slot_id(var)]::VarState).typ - return StateUpdate(var, VarState(typ, false), changes) + return StateUpdate(var, VarState(typ, false), changes, true) end return changes end diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 17ff193db24c0..e2668e15c9e2c 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -87,6 +87,7 @@ struct StateUpdate var::Union{Slot,SSAValue} vtype::VarState state::VarTable + conditional::Bool end # Represent that the type estimate has been approximated, due to "causes" @@ -321,16 +322,19 @@ function stupdate!(state::Nothing, changes::StateUpdate) changeid = slot_id(changes.var::Slot) newst[changeid] = changes.vtype # remove any Conditional for this Slot from the vtable - for i = 1:length(newst) - newtype = newst[i] - if isa(newtype, VarState) - newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid - newtypetyp = widenconditional(newtypetyp) - if newtype.typ isa LimitedAccuracy - newtypetyp = LimitedAccuracy(newtypetyp, newtype.typ.causes) + # (unless this change is came from the conditional) + if !changes.conditional + for i = 1:length(newst) + newtype = newst[i] + if isa(newtype, VarState) + newtypetyp = ignorelimited(newtype.typ) + if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid + newtypetyp = widenconditional(newtypetyp) + if newtype.typ isa LimitedAccuracy + newtypetyp = LimitedAccuracy(newtypetyp, newtype.typ.causes) + end + newst[i] = VarState(newtypetyp, newtype.undef) end - newst[i] = VarState(newtypetyp, newtype.undef) end end end @@ -352,7 +356,8 @@ function stupdate!(state::VarTable, changes::StateUpdate) end oldtype = state[i] # remove any Conditional for this Slot from the vtable - if isa(newtype, VarState) + # (unless this change is came from the conditional) + if !changes.conditional && isa(newtype, VarState) newtypetyp = ignorelimited(newtype.typ) if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid newtypetyp = widenconditional(newtypetyp) @@ -393,16 +398,19 @@ function stupdate1!(state::VarTable, change::StateUpdate) end changeid = slot_id(change.var::Slot) # remove any Conditional for this Slot from the catch block vtable - for i = 1:length(state) - oldtype = state[i] - if isa(oldtype, VarState) - oldtypetyp = ignorelimited(oldtype.typ) - if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid - oldtypetyp = widenconditional(oldtypetyp) - if oldtype.typ isa LimitedAccuracy - oldtypetyp = LimitedAccuracy(oldtypetyp, oldtype.typ.causes) + # (unless this change is came from the conditional) + if !change.conditional + for i = 1:length(state) + oldtype = state[i] + if isa(oldtype, VarState) + oldtypetyp = ignorelimited(oldtype.typ) + if isa(oldtypetyp, Conditional) && slot_id(oldtypetyp.var) == changeid + oldtypetyp = widenconditional(oldtypetyp) + if oldtype.typ isa LimitedAccuracy + oldtypetyp = LimitedAccuracy(oldtypetyp, oldtype.typ.causes) + end + state[i] = VarState(oldtypetyp, oldtype.undef) end - state[i] = VarState(oldtypetyp, oldtype.undef) end end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 73dbedf8a39d5..3ac749afaf936 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -1795,6 +1795,31 @@ end end == Any[Union{Nothing,Expr}] end +@testset "branching on conditional object" begin + # simple + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + return b ? 0 : a # ::Int + end == Any[Int] + + # can use multiple times (as far as the subject of condition hasn't changed) + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + c = b ? 0 : a # c::Int + d = !b ? a : 0 # d::Int + return c, d # ::Tuple{Int,Int} + end == Any[Tuple{Int,Int}] + + # shouldn't use the old constraint when the subject of condition has changed + @test Base.return_types((Union{Nothing,Int},)) do a + b = a === nothing + c = b ? 0 : a # c::Int + a = 0 + d = b ? a : 1 # d::Int, not d::Union{Nothing,Int} + return c, d # ::Tuple{Int,Int} + end == Any[Tuple{Int,Int}] +end + function f25579(g) h = g[] t = (h === nothing)