Skip to content

Commit

Permalink
more accurate lattice implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Aug 19, 2024
1 parent 3229055 commit d1b45b2
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 25 deletions.
28 changes: 19 additions & 9 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,13 @@ end
@nospecializeinfer function (lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
if isa(a, PartialStruct)
if isa(b, PartialStruct)
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
return false
a.typ <: b.typ || return false
if length(a.fields) length(b.fields)
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
length(a.fields) length(b.fields) || return false
else
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
Expand All @@ -493,19 +498,24 @@ end
return isa(b, Type) && a.typ <: b
elseif isa(b, PartialStruct)
if isa(a, Const)
nf = nfields(a.val)
nf == length(b.fields) || return false
widea = widenconst(a)::DataType
wideb = widenconst(b)
wideb′ = unwrap_unionall(wideb)::DataType
widea.name === wideb′.name || return false
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
if wideb′.name !== Tuple.name && !(widea <: wideb)
return false
if wideb′.name === Tuple.name
# We can skip the subtype check if b is a Tuple, since in that
# case, the ⊑ of the elements is sufficient.
# But for tuple comparisons, we need their lengths to be the same for now.
# TODO improve accuracy for cases when `b` contains vararg element
nfields(a.val) == length(b.fields) || return false
else
widea <: wideb || return false
# for structs we need to check that `a` has more information than `b` that may be partially initialized
n_initialized(a) length(b.fields) || return false
end
for i in 1:nf
for i in 1:nfields(a.val)
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
i > length(b.fields) && break # `a` has more information than `b` that is partially initialized struct
bfᵢ = b.fields[i]
if i == nf
bfᵢ = unwrapva(bfᵢ)
Expand Down
36 changes: 20 additions & 16 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,24 @@ end
# even after complicated recursion and other operations on it elsewhere
const issimpleenoughtupleelem = issimpleenoughtype

function n_initialized(t::Const)
nf = nfields(t.val)
return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1
end

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
@assert !isa(typea, LimitedAccuracy) && !isa(typeb, LimitedAccuracy) "LimitedAccuracy not supported by simplertype lattice" # n.b. the caller was supposed to handle these
typea === typeb && return true
if typea isa PartialStruct
aty = widenconst(typea)
if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty))
return false # TODO more accuracy here?
if typeb isa Const
@assert length(typea.fields) n_initialized(typeb) "typeb ⊑ typea is assumed"
elseif typeb isa PartialStruct
@assert length(typea.fields) length(typeb.fields) "typeb ⊑ typea is assumed"
else
return false
end
for i = 1:length(typea.fields)
ai = unwrapva(typea.fields[i])
Expand Down Expand Up @@ -579,26 +588,21 @@ end
aty = widenconst(typea)
bty = widenconst(typeb)
if aty === bty
# must have egal here, since we do not create PartialStruct for non-concrete types
typea_nfields = nfields_tfunc(𝕃, typea)
typeb_nfields = nfields_tfunc(𝕃, typeb)
isa(typea_nfields, Const) || return nothing
isa(typeb_nfields, Const) || return nothing
type_nfields = typea_nfields.val::Int
type_nfields == typeb_nfields.val::Int || return nothing
type_nfields == 0 && return nothing
if typea isa PartialStruct
if typeb isa PartialStruct
length(typea.fields) == length(typeb.fields) || return nothing
nflds = min(length(typea.fields), length(typeb.fields))
else
length(typea.fields) == type_nfields || return nothing
nflds = min(length(typea.fields), n_initialized(typeb::Const))
end
elseif typeb isa PartialStruct
length(typeb.fields) == type_nfields || return nothing
nflds = min(n_initialized(typea::Const), length(typeb.fields))
else
nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const))
end
fields = Vector{Any}(undef, type_nfields)
anyrefine = false
for i = 1:type_nfields
nflds == 0 && return nothing
fields = Vector{Any}(undef, nflds)
anyrefine = nflds > datatype_min_ninitialized(unwrap_unionall(aty))
for i = 1:nflds
ai = getfield_tfunc(𝕃, typea, Const(i))
bi = getfield_tfunc(𝕃, typeb, Const(i))
# N.B.: We're assuming here that !isType(aty), because that case
Expand Down
65 changes: 65 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4754,6 +4754,14 @@ end
@test a c && b c
@test c === typeof(init)
end
let init = Base.ImmutableDict{Any,Any}(1,2)
a = Const(init)
b = PartialStruct(typeof(init), Any[Const(getfield(init,1)), Any, Any])
c = a b
@test a c && b c
@test c isa PartialStruct
@test length(c.fields) == 3
end
let init = Base.ImmutableDict{Number,Number}()
a = Const(init)
b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64])
Expand All @@ -4766,6 +4774,7 @@ end
b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64])
c = a b
@test a c && b c
@test c isa PartialStruct
@test c.fields[2] === Number
@test c.fields[3] === ComplexF64
end
Expand All @@ -4774,9 +4783,40 @@ end
b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
c = a b
@test a c && b c
@test c isa PartialStruct
@test c.fields[2] === Complex
@test c.fields[3] === Complex
end
let T = Base.ImmutableDict{Number,Number}
a = PartialStruct(T, Any[T])
b = PartialStruct(T, Any[T, Number, Number])
@test b a
c = a b
@test a c && b c
@test c isa PartialStruct
@test length(c.fields) == 1
end
let T = Base.ImmutableDict{Number,Number}
a = PartialStruct(T, Any[T])
b = Const(T())
c = a b
@test a c && b c
@test c === T
end
let T = Base.ImmutableDict{Number,Number}
a = Const(T())
b = PartialStruct(T, Any[T])
c = a b
@test a c && b c
@test c === T
end
let T = Base.ImmutableDict{Number,Number}
a = Const(T())
b = Const(T(1,2))
c = a b
@test a c && b c
@test c === T
end

global const ginit43784 = Base.ImmutableDict{Any,Any}()
@test Base.return_types() do
Expand Down Expand Up @@ -4810,6 +4850,31 @@ end
@test a == Tuple
end

let = Core.Compiler.partialorder(Core.Compiler.fallback_lattice)
= Core.Compiler.join(Core.Compiler.fallback_lattice)
Const, PartialStruct = Core.Const, Core.PartialStruct

@test (Const((1,2)) PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]))
@test !(Const((1,2)) PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]))
@test !(Const((1,2,3)) PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]))
@test (Const((1,2,3)) PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]))
@test (Const((1,2)) PartialStruct(Tuple{Int,Vararg{Int}}, Any[Const(1),Vararg{Int}]))
@test (Const((1,2)) PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}])) broken=true
@test (Const((1,2,3)) PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}]))
@test !(PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]) Const((1,2)))
@test !(PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]) Const((1,2)))
@test !(PartialStruct(Tuple{Int,Int}, Any[Const(1),Int]) Const((1,2,3)))
@test !(PartialStruct(Tuple{Int,Int,Int}, Any[Const(1),Int,Int]) Const((1,2,3)))
@test !(PartialStruct(Tuple{Int,Vararg{Int}}, Any[Const(1),Vararg{Int}]) Const((1,2)))
@test !(PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}]) Const((1,2)))
@test !(PartialStruct(Tuple{Int,Int,Vararg{Int}}, Any[Const(1),Int,Vararg{Int}]) Const((1,2,3)))

t = Const((false, false)) Const((false, true))
@test t isa PartialStruct && length(t.fields) == 2 && t.fields[1] === Const(false)
t = t Const((false, false, 0))
@test t Union{Tuple{Bool,Bool},Tuple{Bool,Bool,Int}}
end

# Test that a function-wise `@max_methods` works as expected
Base.Experimental.@max_methods 1 function f_max_methods end
f_max_methods(x::Int) = 1
Expand Down

0 comments on commit d1b45b2

Please sign in to comment.