From d1b45b216b27d45b2f3af98bda48b0d3224fbe8a Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Fri, 16 Aug 2024 15:10:08 +0900 Subject: [PATCH] more accurate lattice implementation --- base/compiler/typelattice.jl | 28 +++++++++++----- base/compiler/typelimits.jl | 36 +++++++++++--------- test/compiler/inference.jl | 65 ++++++++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 25 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 7565740338a1dc..16c27aea928a5f 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -469,8 +469,13 @@ end @nospecializeinfer function ⊑(lattice::PartialsLattice, @nospecialize(a), @nospecialize(b)) if isa(a, PartialStruct) if isa(b, PartialStruct) - if !(length(a.fields) == length(b.fields) && a.typ <: b.typ) - return false + a.typ <: b.typ || return false + if length(a.fields) ≠ length(b.fields) + if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end])) + length(a.fields) ≥ length(b.fields) || return false + else + return false + end end for i in 1:length(b.fields) af = a.fields[i] @@ -493,19 +498,24 @@ end return isa(b, Type) && a.typ <: b elseif isa(b, PartialStruct) if isa(a, Const) - nf = nfields(a.val) - nf == length(b.fields) || return false widea = widenconst(a)::DataType wideb = widenconst(b) wideb′ = unwrap_unionall(wideb)::DataType widea.name === wideb′.name || return false - # We can skip the subtype check if b is a Tuple, since in that - # case, the ⊑ of the elements is sufficient. - if wideb′.name !== Tuple.name && !(widea <: wideb) - return false + if wideb′.name === Tuple.name + # We can skip the subtype check if b is a Tuple, since in that + # case, the ⊑ of the elements is sufficient. + # But for tuple comparisons, we need their lengths to be the same for now. + # TODO improve accuracy for cases when `b` contains vararg element + nfields(a.val) == length(b.fields) || return false + else + widea <: wideb || return false + # for structs we need to check that `a` has more information than `b` that may be partially initialized + n_initialized(a) ≥ length(b.fields) || return false end - for i in 1:nf + for i in 1:nfields(a.val) isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T + i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct bfᵢ = b.fields[i] if i == nf bfᵢ = unwrapva(bfᵢ) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 1a1aa7a35e840e..91a44d3b117abd 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -321,6 +321,11 @@ end # even after complicated recursion and other operations on it elsewhere const issimpleenoughtupleelem = issimpleenoughtype +function n_initialized(t::Const) + nf = nfields(t.val) + return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1 +end + # A simplified type_more_complex query over the extended lattice # (assumes typeb ⊑ typea) @nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb)) @@ -328,8 +333,12 @@ const issimpleenoughtupleelem = issimpleenoughtype typea === typeb && return true if typea isa PartialStruct aty = widenconst(typea) - if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty)) - return false # TODO more accuracy here? + if typeb isa Const + @assert length(typea.fields) ≤ n_initialized(typeb) "typeb ⊑ typea is assumed" + elseif typeb isa PartialStruct + @assert length(typea.fields) ≤ length(typeb.fields) "typeb ⊑ typea is assumed" + else + return false end for i = 1:length(typea.fields) ai = unwrapva(typea.fields[i]) @@ -579,26 +588,21 @@ end aty = widenconst(typea) bty = widenconst(typeb) if aty === bty - # must have egal here, since we do not create PartialStruct for non-concrete types - typea_nfields = nfields_tfunc(𝕃, typea) - typeb_nfields = nfields_tfunc(𝕃, typeb) - isa(typea_nfields, Const) || return nothing - isa(typeb_nfields, Const) || return nothing - type_nfields = typea_nfields.val::Int - type_nfields == typeb_nfields.val::Int || return nothing - type_nfields == 0 && return nothing if typea isa PartialStruct if typeb isa PartialStruct - length(typea.fields) == length(typeb.fields) || return nothing + nflds = min(length(typea.fields), length(typeb.fields)) else - length(typea.fields) == type_nfields || return nothing + nflds = min(length(typea.fields), n_initialized(typeb::Const)) end elseif typeb isa PartialStruct - length(typeb.fields) == type_nfields || return nothing + nflds = min(n_initialized(typea::Const), length(typeb.fields)) + else + nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const)) end - fields = Vector{Any}(undef, type_nfields) - anyrefine = false - for i = 1:type_nfields + nflds == 0 && return nothing + fields = Vector{Any}(undef, nflds) + anyrefine = nflds > datatype_min_ninitialized(unwrap_unionall(aty)) + for i = 1:nflds ai = getfield_tfunc(𝕃, typea, Const(i)) bi = getfield_tfunc(𝕃, typeb, Const(i)) # N.B.: We're assuming here that !isType(aty), because that case diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index b5c1321c94eb37..75f33a280e245c 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4754,6 +4754,14 @@ end @test a ⊑ c && b ⊑ c @test c === typeof(init) end + let init = Base.ImmutableDict{Any,Any}(1,2) + a = Const(init) + b = PartialStruct(typeof(init), Any[Const(getfield(init,1)), Any, Any]) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c isa PartialStruct + @test length(c.fields) == 3 + end let init = Base.ImmutableDict{Number,Number}() a = Const(init) b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) @@ -4766,6 +4774,7 @@ end b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64]) c = a ⊔ b @test a ⊑ c && b ⊑ c + @test c isa PartialStruct @test c.fields[2] === Number @test c.fields[3] === ComplexF64 end @@ -4774,9 +4783,40 @@ end b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) c = a ⊔ b @test a ⊑ c && b ⊑ c + @test c isa PartialStruct @test c.fields[2] === Complex @test c.fields[3] === Complex end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = PartialStruct(T, Any[T, Number, Number]) + @test b ⊑ a + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c isa PartialStruct + @test length(c.fields) == 1 + end + let T = Base.ImmutableDict{Number,Number} + a = PartialStruct(T, Any[T]) + b = Const(T()) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = PartialStruct(T, Any[T]) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end + let T = Base.ImmutableDict{Number,Number} + a = Const(T()) + b = Const(T(1,2)) + c = a ⊔ b + @test a ⊑ c && b ⊑ c + @test c === T + end global const ginit43784 = Base.ImmutableDict{Any,Any}() @test Base.return_types() do @@ -4810,6 +4850,31 @@ end @test a == Tuple end +let ⊑ = Core.Compiler.partialorder(Core.Compiler.fallback_lattice) + ⊔ = Core.Compiler.join(Core.Compiler.fallback_lattice) + Const, PartialStruct = Core.Const, Core.PartialStruct + + @test (Const((1,2)) ⊑ PartialStruct(Tuple{Int,Int}, Any[Const(1),Int])) + @test !(Const((1,2)) ⊑ PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int])) + @test !(Const((1,2,3)) ⊑ PartialStruct(Tuple{Int,Int}, Any[Const(1),Int])) + @test (Const((1,2,3)) ⊑ PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int])) + @test (Const((1,2)) ⊑ PartialStruct(Tuple{Int,Vararg{Int}}, Any[Const(1),Vararg{Int}])) + @test (Const((1,2)) ⊑ PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}])) broken=true + @test (Const((1,2,3)) ⊑ PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}])) + @test !(PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]) ⊑ Const((1,2))) + @test !(PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]) ⊑ Const((1,2))) + @test !(PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]) ⊑ Const((1,2,3))) + @test !(PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]) ⊑ Const((1,2,3))) + @test !(PartialStruct(Tuple{Int,Vararg{Int}}, Any[Const(1),Vararg{Int}]) ⊑ Const((1,2))) + @test !(PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}]) ⊑ Const((1,2))) + @test !(PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}]) ⊑ Const((1,2,3))) + + t = Const((false, false)) ⊔ Const((false, true)) + @test t isa PartialStruct && length(t.fields) == 2 && t.fields[1] === Const(false) + t = t ⊔ Const((false, false, 0)) + @test t ⊑ Union{Tuple{Bool,Bool},Tuple{Bool,Bool,Int}} +end + # Test that a function-wise `@max_methods` works as expected Base.Experimental.@max_methods 1 function f_max_methods end f_max_methods(x::Int) = 1