From 0b7ee72491d4b96e53084a8efbdb5953b672b66c Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Tue, 1 Aug 2023 23:32:13 -0400 Subject: [PATCH] inference: permit recursive type traits (#50694) We had a special case for Type that disallowed type trait recursion in favor of a pattern that almost never appears in code (only once in the compiler by accident where it doesn't matter). This was unnecessarily confusing and unexpected to predict what can infer, and made traits harder than necessary (such as Broadcast.ndims since 70fc3cdc11b). Fix #43296 Fix #43368 (cherry picked from commit 33e3d9f7de229a109cc2afeb72be2bb7931d3e79) --- base/compiler/typelimits.jl | 45 +++++++++++++++++++-------------- test/compiler/inference.jl | 50 +++++++++++++++++++++++++++++++++---- 2 files changed, 71 insertions(+), 24 deletions(-) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index b648144ea3bd1e..e24fd257f02a6d 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -116,15 +116,32 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec return Union{a, b} end elseif isa(t, DataType) - if isType(t) # see equivalent case in type_more_complex - tt = unwrap_unionall(t.parameters[1]) - if isa(tt, Union) || isa(tt, TypeVar) || isType(tt) - is_derived_type_from_any(tt, sources, depth + 1) && return t + if isType(t) + # Type is fairly important, so do not widen it as fast as other types if avoidable + tt = t.parameters[1] + ttu = unwrap_unionall(tt) # TODO: use argument_datatype(tt) after #50692 fixed + # must forbid nesting through this if we detect that potentially occurring + # we already know !is_derived_type_from_any so refuse to recurse here + if !isa(ttu, DataType) + return Type + elseif isType(ttu) + return Type{<:Type} + end + # try to peek into c to get a comparison object, but if we can't perhaps t is already simple enough on its own + # (this is slightly more permissive than type_more_complex implements for the same case). + if isType(c) + ct = c.parameters[1] else - isType(c) && (c = unwrap_unionall(c.parameters[1])) - type_more_complex(tt, c, sources, depth, 0, 0) || return t + ct = Union{} end - return Type + Qt = __limit_type_size(tt, ct, sources, depth + 1, 0) + Qt === Any && return Type + Qt === tt && return t + # Can't form Type{<:Qt} just yet, without first make sure we limited the depth + # enough, since this moves Qt outside of Type for is_derived_type_from_any + Qt = __limit_type_size(tt, ct, sources, depth + 2, 0) + Qt === Any && return Type + return Type{<:Qt} elseif isa(c, DataType) tP = t.parameters cP = c.parameters @@ -157,6 +174,7 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec end end if allowed_tuplelen < 1 && t.name === Tuple.name + # forbid nesting Tuple{Tuple{Tuple...}} through this return Any end widert = t.name.wrapper @@ -247,18 +265,7 @@ function type_more_complex(@nospecialize(t), @nospecialize(c), sources::SimpleVe # base case for data types if isa(t, DataType) tP = t.parameters - if isType(t) - # Treat Type{T} and T as equivalent to allow taking typeof any - # source type (DataType) anywhere as Type{...}, as long as it isn't - # nesting as Type{Type{...}} - tt = unwrap_unionall(t.parameters[1]) - if isa(tt, Union) || isa(tt, TypeVar) || isType(tt) - return !is_derived_type_from_any(tt, sources, depth + 1) - else - isType(c) && (c = unwrap_unionall(c.parameters[1])) - return type_more_complex(tt, c, sources, depth, 0, 0) - end - elseif isa(c, DataType) && t.name === c.name + if isa(c, DataType) && t.name === c.name cP = c.parameters length(cP) < length(tP) && return true isempty(tP) && return false diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 2ef172b3e36437..a34ba18e0d04ee 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -60,7 +60,8 @@ end # issue #42835 @test !Core.Compiler.type_more_complex(Int, Any, Core.svec(), 1, 1, 1) @test !Core.Compiler.type_more_complex(Int, Type{Int}, Core.svec(), 1, 1, 1) -@test !Core.Compiler.type_more_complex(Type{Int}, Any, Core.svec(), 1, 1, 1) +@test Core.Compiler.type_more_complex(Type{Int}, Any, Core.svec(), 1, 1, 1) # maybe should be fixed? +@test Core.Compiler.limit_type_size(Type{Int}, Any, Union{}, 0, 0) == Type{Int} @test Core.Compiler.type_more_complex(Type{Type{Int}}, Type{Int}, Core.svec(Type{Int}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Int}}, Int, Core.svec(Type{Int}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Int}}, Any, Core.svec(), 1, 1, 1) @@ -71,22 +72,23 @@ end @test Core.Compiler.type_more_complex(ComplexF32, Type{ComplexF32}, Core.svec(), 1, 1, 1) @test !Core.Compiler.type_more_complex(Type{ComplexF32}, Any, Core.svec(Type{Type{ComplexF32}}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{ComplexF32}, Type{Type{ComplexF32}}, Core.svec(), 1, 1, 1) -@test !Core.Compiler.type_more_complex(Type{ComplexF32}, ComplexF32, Core.svec(), 1, 1, 1) +@test Core.Compiler.type_more_complex(Type{ComplexF32}, ComplexF32, Core.svec(), 1, 1, 1) +@test Core.Compiler.limit_type_size(Type{ComplexF32}, ComplexF32, Union{}, 1, 1) == Type{<:Complex} @test Core.Compiler.type_more_complex(Type{ComplexF32}, Any, Core.svec(), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{ComplexF32}}, Type{ComplexF32}, Core.svec(Type{ComplexF32}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{ComplexF32}}, ComplexF32, Core.svec(ComplexF32), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Type{ComplexF32}}}, Type{Type{ComplexF32}}, Core.svec(Type{ComplexF32}), 1, 1, 1) # n.b. Type{Type{Union{}} === Type{Core.TypeofBottom} -@test !Core.Compiler.type_more_complex(Type{Union{}}, Any, Core.svec(), 1, 1, 1) -@test !Core.Compiler.type_more_complex(Type{Type{Union{}}}, Any, Core.svec(), 1, 1, 1) +@test Core.Compiler.type_more_complex(Type{Union{}}, Any, Core.svec(), 1, 1, 1) +@test Core.Compiler.type_more_complex(Type{Type{Union{}}}, Any, Core.svec(), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Type{Union{}}}}, Any, Core.svec(), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Type{Union{}}}}, Type{Type{Union{}}}, Core.svec(Type{Type{Union{}}}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Type{Type{Type{Union{}}}}}, Type{Type{Type{Union{}}}}, Core.svec(Type{Type{Type{Union{}}}}), 1, 1, 1) @test !Core.Compiler.type_more_complex(Type{1}, Type{2}, Core.svec(), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{Union{Float32,Float64}}, Union{Float32,Float64}, Core.svec(Union{Float32,Float64}), 1, 1, 1) -@test !Core.Compiler.type_more_complex(Type{Union{Float32,Float64}}, Union{Float32,Float64}, Core.svec(Union{Float32,Float64}), 0, 1, 1) +@test Core.Compiler.type_more_complex(Type{Union{Float32,Float64}}, Union{Float32,Float64}, Core.svec(Union{Float32,Float64}), 0, 1, 1) @test Core.Compiler.type_more_complex(Type{<:Union{Float32,Float64}}, Type{Union{Float32,Float64}}, Core.svec(Union{Float32,Float64}), 1, 1, 1) @test Core.Compiler.type_more_complex(Type{<:Union{Float32,Float64}}, Any, Core.svec(Union{Float32,Float64}), 1, 1, 1) @@ -101,6 +103,44 @@ let # 40336 @test t !== r && t <: r end +@test Core.Compiler.limit_type_size(Type{Type{Type{Int}}}, Type, Union{}, 0, 0) == Type{<:Type} +@test Core.Compiler.limit_type_size(Type{Type{Int}}, Type, Union{}, 0, 0) == Type{<:Type} +@test Core.Compiler.limit_type_size(Type{Int}, Type, Union{}, 0, 0) == Type{Int} +@test Core.Compiler.limit_type_size(Type{<:Int}, Type, Union{}, 0, 0) == Type{<:Int} +@test Core.Compiler.limit_type_size(Type{ComplexF32}, ComplexF32, Union{}, 0, 0) == Type{<:Complex} # added nesting +@test Core.Compiler.limit_type_size(Type{ComplexF32}, Type{ComplexF64}, Union{}, 0, 0) == Type{ComplexF32} # base matches +@test Core.Compiler.limit_type_size(Type{ComplexF32}, Type, Union{}, 0, 0) == Type{<:Complex} +@test_broken Core.Compiler.limit_type_size(Type{<:ComplexF64}, Type, Union{}, 0, 0) == Type{<:Complex} +@test Core.Compiler.limit_type_size(Type{<:ComplexF64}, Type, Union{}, 0, 0) == Type #50692 +@test Core.Compiler.limit_type_size(Type{Union{ComplexF32,ComplexF64}}, Type, Union{}, 0, 0) == Type +@test_broken Core.Compiler.limit_type_size(Type{Union{ComplexF32,ComplexF64}}, Type, Union{}, 0, 0) == Type{<:Complex} #50692 +@test Core.Compiler.limit_type_size(Type{Union{Float32,Float64}}, Type, Union{}, 0, 0) == Type +@test Core.Compiler.limit_type_size(Type{Union{Int,Type{Int}}}, Type{Type{Int}}, Union{}, 0, 0) == Type +@test Core.Compiler.limit_type_size(Type{Union{Int,Type{Int}}}, Union{Type{Int},Type{Type{Int}}}, Union{}, 0, 0) == Type +@test Core.Compiler.limit_type_size(Type{Union{Int,Type{Int}}}, Type{Union{Type{Int},Type{Type{Int}}}}, Union{}, 0, 0) == Type{Union{Int, Type{Int}}} +@test Core.Compiler.limit_type_size(Type{Union{Int,Type{Int}}}, Type{Type{Int}}, Union{}, 0, 0) == Type + + +# issue #43296 #43296 +struct C43296{t,I} end +r43296(b) = r43296(typeof(b)) +r43296(::Type) = nothing +r43296(::Nothing) = nonexistent +r43296(::Type{C43296{c,d}}) where {c,d} = f43296(r43296(c), e) +f43296(::Nothing, :) = nothing +f43296(g, :) = h +k43296(b, j, :) = l +k43296(b, j, ::Nothing) = b +i43296(b, j) = k43296(b, j, r43296(j)) +@test only(Base.return_types(i43296, (Int, C43296{C43296{C43296{Val, Tuple}, Tuple}}))) == Int + +abstract type e43296{a, j} <: AbstractArray{a, j} end +abstract type b43296{a, j, c, d} <: e43296{a, j} end +struct h43296{a, j, f, d, i} <: b43296{a, j, f, d} end +Base.ndims(::Type{f}) where {f<:e43296} = ndims(supertype(f)) +Base.ndims(g::e43296) = ndims(typeof(g)) +@test only(Base.return_types(ndims, (h43296{Any, 0, Any, Int, Any},))) == Int + @test Core.Compiler.unionlen(Union{}) == 1 @test Core.Compiler.unionlen(Int8) == 1 @test Core.Compiler.unionlen(Union{Int8, Int16}) == 2