Skip to content

Commit

Permalink
more accurate lattice implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Aug 19, 2024
1 parent 3229055 commit d22ecae
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 20 deletions.
14 changes: 10 additions & 4 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,13 @@ end
@nospecializeinfer function (lattice::PartialsLattice, @nospecialize(a), @nospecialize(b))
if isa(a, PartialStruct)
if isa(b, PartialStruct)
if !(length(a.fields) == length(b.fields) && a.typ <: b.typ)
return false
a.typ <: b.typ || return false
if length(a.fields) length(b.fields)
if !(isvarargtype(a.fields[end]) || isvarargtype(b.fields[end]))
length(a.fields) length(b.fields) || return false
else
return false
end
end
for i in 1:length(b.fields)
af = a.fields[i]
Expand All @@ -493,8 +498,7 @@ end
return isa(b, Type) && a.typ <: b
elseif isa(b, PartialStruct)
if isa(a, Const)
nf = nfields(a.val)
nf == length(b.fields) || return false
n_initialized(a) length(b.fields) || return false
widea = widenconst(a)::DataType
wideb = widenconst(b)
wideb′ = unwrap_unionall(wideb)::DataType
Expand All @@ -504,8 +508,10 @@ end
if wideb′.name !== Tuple.name && !(widea <: wideb)
return false
end
nf = nfields(a.val)
for i in 1:nf
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
i > length(b.fields) && break
bfᵢ = b.fields[i]
if i == nf
bfᵢ = unwrapva(bfᵢ)
Expand Down
36 changes: 20 additions & 16 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,24 @@ end
# even after complicated recursion and other operations on it elsewhere
const issimpleenoughtupleelem = issimpleenoughtype

function n_initialized(t::Const)
nf = nfields(t.val)
return something(findfirst(i::Int->!isdefined(t.val,i), 1:nf), nf+1)-1
end

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
@nospecializeinfer function issimplertype(𝕃::AbstractLattice, @nospecialize(typea), @nospecialize(typeb))
@assert !isa(typea, LimitedAccuracy) && !isa(typeb, LimitedAccuracy) "LimitedAccuracy not supported by simplertype lattice" # n.b. the caller was supposed to handle these
typea === typeb && return true
if typea isa PartialStruct
aty = widenconst(typea)
if length(typea.fields) > datatype_min_ninitialized(unwrap_unionall(aty))
return false # TODO more accuracy here?
if typeb isa Const
@assert length(typea.fields) n_initialized(typeb) "typeb ⊑ typea is assumed"
elseif typeb isa PartialStruct
@assert length(typea.fields) length(typeb.fields) "typeb ⊑ typea is assumed"
else
return false
end
for i = 1:length(typea.fields)
ai = unwrapva(typea.fields[i])
Expand Down Expand Up @@ -579,26 +588,21 @@ end
aty = widenconst(typea)
bty = widenconst(typeb)
if aty === bty
# must have egal here, since we do not create PartialStruct for non-concrete types
typea_nfields = nfields_tfunc(𝕃, typea)
typeb_nfields = nfields_tfunc(𝕃, typeb)
isa(typea_nfields, Const) || return nothing
isa(typeb_nfields, Const) || return nothing
type_nfields = typea_nfields.val::Int
type_nfields == typeb_nfields.val::Int || return nothing
type_nfields == 0 && return nothing
if typea isa PartialStruct
if typeb isa PartialStruct
length(typea.fields) == length(typeb.fields) || return nothing
nflds = min(length(typea.fields), length(typeb.fields))
else
length(typea.fields) == type_nfields || return nothing
nflds = min(length(typea.fields), n_initialized(typeb::Const))
end
elseif typeb isa PartialStruct
length(typeb.fields) == type_nfields || return nothing
nflds = min(n_initialized(typea::Const), length(typeb.fields))
else
nflds = min(n_initialized(typea::Const), n_initialized(typeb::Const))
end
fields = Vector{Any}(undef, type_nfields)
anyrefine = false
for i = 1:type_nfields
nflds == 0 && return nothing
fields = Vector{Any}(undef, nflds)
anyrefine = nflds > datatype_min_ninitialized(unwrap_unionall(aty))
for i = 1:nflds
ai = getfield_tfunc(𝕃, typea, Const(i))
bi = getfield_tfunc(𝕃, typeb, Const(i))
# N.B.: We're assuming here that !isType(aty), because that case
Expand Down
40 changes: 40 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4754,6 +4754,14 @@ end
@test a c && b c
@test c === typeof(init)
end
let init = Base.ImmutableDict{Any,Any}(1,2)
a = Const(init)
b = PartialStruct(typeof(init), Any[Const(getfield(init,1)), Any, Any])
c = a b
@test a c && b c
@test c isa PartialStruct
@test length(c.fields) == 3
end
let init = Base.ImmutableDict{Number,Number}()
a = Const(init)
b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64])
Expand All @@ -4766,6 +4774,7 @@ end
b = PartialStruct(typeof(init), Any[Const(init), Number, ComplexF64])
c = a b
@test a c && b c
@test c isa PartialStruct
@test c.fields[2] === Number
@test c.fields[3] === ComplexF64
end
Expand All @@ -4774,9 +4783,40 @@ end
b = PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
c = a b
@test a c && b c
@test c isa PartialStruct
@test c.fields[2] === Complex
@test c.fields[3] === Complex
end
let T = Base.ImmutableDict{Number,Number}
a = PartialStruct(T, Any[T])
b = PartialStruct(T, Any[T, Number, Number])
@test b a
c = a b
@test a c && b c
@test c isa PartialStruct
@test length(c.fields) == 1
end
let T = Base.ImmutableDict{Number,Number}
a = PartialStruct(T, Any[T])
b = Const(T())
c = a b
@test a c && b c
@test c === T
end
let T = Base.ImmutableDict{Number,Number}
a = Const(T())
b = PartialStruct(T, Any[T])
c = a b
@test a c && b c
@test c === T
end
let T = Base.ImmutableDict{Number,Number}
a = Const(T())
b = Const(T(1,2))
c = a b
@test a c && b c
@test c === T
end

global const ginit43784 = Base.ImmutableDict{Any,Any}()
@test Base.return_types() do
Expand Down

0 comments on commit d22ecae

Please sign in to comment.