Skip to content

Commit

Permalink
reflection: add Base.infer_return_type utility (JuliaLang#52247)
Browse files Browse the repository at this point in the history
This commit introduces `Base.infer_return_type`, a new reflection
utility which shares a similar interface with `Base.return_types` but
differs in its output; `Base.infer_return_type` provides a singular
return type taking into account all potential outcomes specified with
the given call signature. This function parallels `Base.infer_effects`
and the newly added `Base.infer_exception_type`, offering some utility,
especially in testing scenarios.
  • Loading branch information
aviatesk authored and mkitti committed Dec 9, 2023
1 parent 9c1b55f commit 8bdeceb
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 12 deletions.
9 changes: 6 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,11 @@ function typeinf_type(interp::AbstractInterpreter, method::Method, @nospecialize
if contains_is(unwrap_unionall(atype).parameters, Union{})
return Union{} # don't ask: it does weird and unnecessary things, if it occurs during bootstrap
end
mi = specialize_method(method, atype, sparams)::MethodInstance
return typeinf_type(interp, specialize_method(method, atype, sparams))
end
typeinf_type(interp::AbstractInterpreter, match::MethodMatch) =
typeinf_type(interp, specialize_method(match))
function typeinf_type(interp::AbstractInterpreter, mi::MethodInstance)
start_time = ccall(:jl_typeinf_timing_begin, UInt64, ())
code = get(code_cache(interp), mi, nothing)
if code isa CodeInstance
Expand Down Expand Up @@ -1120,8 +1124,7 @@ function _return_type(interp::AbstractInterpreter, t::DataType)
rt = widenconst(rt)
else
for match in _methods_by_ftype(t, -1, get_world_counter(interp))::Vector
match = match::MethodMatch
ty = typeinf_type(interp, match.method, match.spec_types, match.sparams)
ty = typeinf_type(interp, match::MethodMatch)
ty === nothing && return Any
rt = tmerge(rt, ty)
rt === Any && break
Expand Down
85 changes: 79 additions & 6 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1710,6 +1710,8 @@ check_generated_context(world::UInt) =
(ccall(:jl_is_in_pure_context, Bool, ()) || world == typemax(UInt)) &&
error("code reflection cannot be used from generated functions")

# TODO rename `Base.return_types` to `Base.infer_return_types`

"""
Base.return_types(
f, types=default_tt(f);
Expand Down Expand Up @@ -1741,9 +1743,9 @@ julia> Base.return_types(sum, Tuple{Vector{Int}})
julia> methods(sum, (Union{Vector{Int},UnitRange{Int}},))
# 2 methods for generic function "sum" from Base:
[1] sum(r::AbstractRange{<:Real})
@ range.jl:1396
@ range.jl:1399
[2] sum(a::AbstractArray; dims, kw...)
@ reducedim.jl:996
@ reducedim.jl:1010
julia> Base.return_types(sum, (Union{Vector{Int},UnitRange{Int}},))
2-element Vector{Any}:
Expand Down Expand Up @@ -1771,13 +1773,84 @@ function return_types(@nospecialize(f), @nospecialize(types=default_tt(f));
tt = signature_type(f, types)
matches = _methods_by_ftype(tt, #=lim=#-1, world)::Vector
for match in matches
match = match::Core.MethodMatch
ty = Core.Compiler.typeinf_type(interp, match.method, match.spec_types, match.sparams)
ty = Core.Compiler.typeinf_type(interp, match::Core.MethodMatch)
push!(rts, something(ty, Any))
end
return rts
end

"""
Base.infer_return_type(
f, types=default_tt(f);
world::UInt=get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world)) -> rt::Type
Returns an inferred return type of the function call specified by `f` and `types`.
# Arguments
- `f`: The function to analyze.
- `types` (optional): The argument types of the function. Defaults to the default tuple type of `f`.
- `world` (optional): The world counter to use for the analysis. Defaults to the current world counter.
- `interp` (optional): The abstract interpreter to use for the analysis. Defaults to a new `Core.Compiler.NativeInterpreter` with the specified `world`.
# Returns
- `rt::Type`: An inferred return type of the function call specified by the given call signature.
!!! note
Note that, different from [`Base.return_types`](@ref), this doesn't give you the list
return types of every possible method matching with the given `f` and `types`.
It returns a single return type, taking into account all potential outcomes of
any function call entailed by the given signature type.
# Example
```julia
julia> checksym(::Symbol) = :symbol;
julia> checksym(x::Any) = x;
julia> Base.infer_return_type(checksym, (Union{Symbol,String},))
Union{String, Symbol}
julia> Base.return_types(checksym, (Union{Symbol,String},))
2-element Vector{Any}:
Symbol
Union{String, Symbol}
```
It's important to note the difference here: `Base.return_types` gives back inferred results
for each method that matches the given signature `checksum(::Union{Symbol,String})`.
On the other hand `Base.infer_return_type` returns one collective result that sums up all those possibilities.
!!! warning
The `Base.infer_return_type` function should not be used from generated functions;
doing so will result in an error.
"""
function infer_return_type(@nospecialize(f), @nospecialize(types=default_tt(f));
world::UInt=get_world_counter(),
interp::Core.Compiler.AbstractInterpreter=Core.Compiler.NativeInterpreter(world))
check_generated_context(world)
if isa(f, Core.OpaqueClosure)
return last(only(code_typed_opaque_closure(f)))
end
if isa(f, Core.Builtin)
return _builtin_return_type(interp, f, types)
end
tt = signature_type(f, types)
matches = Core.Compiler.findall(tt, Core.Compiler.method_table(interp))
if matches === nothing
# unanalyzable call, i.e. the interpreter world might be newer than the world where
# the `f` is defined, return the unknown return type
return Any
end
rt = Union{}
for match in matches.matches
ty = Core.Compiler.typeinf_type(interp, match::Core.MethodMatch)
rt = Core.Compiler.tmerge(rt, something(ty, Any))
end
return rt
end

"""
Base.infer_exception_types(
f, types=default_tt(f);
Expand Down Expand Up @@ -1880,7 +1953,7 @@ Returns the type of exception potentially thrown by the function call specified
!!! note
Note that, different from [`Base.infer_exception_types`](@ref), this doesn't give you the list
exception types for every possible matching method with the given `f` and `types`.
It provides a single exception type, taking into account all potential outcomes of
It returns a single exception type, taking into account all potential outcomes of
any function call entailed by the given signature type.
# Example
Expand Down Expand Up @@ -1964,7 +2037,7 @@ Returns the possible computation effects of the function call specified by `f` a
!!! note
Note that, different from [`Base.return_types`](@ref), this doesn't give you the list
effect analysis results for every possible matching method with the given `f` and `types`.
It provides a single effect, taking into account all potential outcomes of any function
It returns a single effect, taking into account all potential outcomes of any function
call entailed by the given signature type.
# Example
Expand Down
23 changes: 20 additions & 3 deletions test/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,22 @@ ambig_effects_test(a::Int, b) = 1
ambig_effects_test(a, b::Int) = 1
ambig_effects_test(a, b) = 1

@testset "infer_effects" begin
@testset "Base.infer_return_type[s]" begin
# generic function case
@test only(Base.return_types(issue41694, (Int,))) == Base.infer_return_type(issue41694, (Int,)) == Int
# case when it's not fully covered
@test only(Base.return_types(issue41694, (Integer,))) == Base.infer_return_type(issue41694, (Integer,)) == Int
# MethodError case
@test isempty(Base.return_types(issue41694, (Float64,)))
@test Base.infer_return_type(issue41694, (Float64,)) == Union{}
# builtin case
@test only(Base.return_types(typeof, (Any,))) == Base.infer_return_type(typeof, (Any,)) == DataType
@test only(Base.return_types(===, (Any,Any))) == Base.infer_return_type(===, (Any,Any)) == Bool
@test only(Base.return_types(setfield!, ())) == Base.infer_return_type(setfield!, ()) == Union{}
@test only(Base.return_types(Core.Intrinsics.mul_int, ())) == Base.infer_return_type(Core.Intrinsics.mul_int, ()) == Union{}
end

@testset "Base.infer_effects" begin
# generic functions
@test Base.infer_effects(issue41694, (Int,)) |> Core.Compiler.is_terminates
@test Base.infer_effects((Int,)) do x
Expand All @@ -1047,7 +1062,7 @@ ambig_effects_test(a, b) = 1
@test (Base.infer_effects(Core.Intrinsics.mul_int, ()); true) # `intrinsic_effects` shouldn't throw on empty `argtypes`
end

@testset "infer_exception_type[s]" begin
@testset "Base.infer_exception_type[s]" begin
# generic functions
@test Base.infer_exception_type(issue41694, (Int,)) == only(Base.infer_exception_types(issue41694, (Int,))) == ErrorException
@test Base.infer_exception_type((Int,)) do x
Expand Down Expand Up @@ -1119,7 +1134,9 @@ end
return :(x)
end
end
@test only(Base.return_types(generated_only_simple, (Real,))) == Core.Compiler.return_type(generated_only_simple, Tuple{Real}) == Any
@test only(Base.return_types(generated_only_simple, (Real,))) ==
Base.infer_return_type(generated_only_simple, (Real,)) ==
Core.Compiler.return_type(generated_only_simple, Tuple{Real}) == Any
let (src, rt) = only(code_typed(generated_only_simple, (Real,)))
@test src isa Method
@test rt == Any
Expand Down

0 comments on commit 8bdeceb

Please sign in to comment.