Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional argument to @inferred that relaxes the case of Nothing and Missing #27516

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 44 additions & 15 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Distributed: myid
using Random
using Random: AbstractRNG, GLOBAL_RNG
using InteractiveUtils: gen_call_with_extracted_types
using Core.Compiler: typesubtract

#-----------------------------------------------------------------------

Expand Down Expand Up @@ -1262,39 +1263,65 @@ end
_args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...))
_materialize_broadcasted(f, args...) = Broadcast.materialize(Broadcast.broadcasted(f, args...))
"""
@inferred f(x)
@inferred [AllowedType] f(x)

Tests that the call expression `f(x)` returns a value of the same type
inferred by the compiler. It is useful to check for type stability.
Tests that the call expression `f(x)` returns a value of the same type inferred by the
compiler. It is useful to check for type stability.

`f(x)` can be any call expression.
Returns the result of `f(x)` if the types match,
and an `Error` `Result` if it finds different types.
`f(x)` can be any call expression. Returns the result of `f(x)` if the types match, and an
`Error` `Result` if it finds different types.

Optionally, `AllowedType` relaxes the test, by making it pass when either the type of `f(x)`
matches the inferred type modulo `AllowedType`, or when the return type is a subtype of
`AllowedType`. This is useful when testing type stability of functions returning a small
union such as `Union{Nothing, T}` or `Union{Missing, T}`.

```jldoctest; setup = :(using InteractiveUtils), filter = r"begin\\n(.|\\n)*end"
julia> f(a, b, c) = b > 1 ? 1 : 1.0
julia> f(a) = a > 1 ? 1 : 1.0
f (generic function with 1 method)

julia> typeof(f(1, 2, 3))
julia> typeof(f(2))
Int64

julia> @code_warntype f(1, 2, 3)
julia> @code_warntype f(2)
Body::UNION{FLOAT64, INT64}
1 ─ %1 = (Base.slt_int)(1, b)::Bool
1 ─ %1 = (Base.slt_int)(1, a)::Bool
└── goto #3 if not %1
2 ─ return 1
3 ─ return 1.0

julia> @inferred f(1, 2, 3)
julia> @inferred f(2)
ERROR: return type Int64 does not match inferred return type Union{Float64, Int64}
Stacktrace:
[...]

julia> @inferred max(1, 2)
2

julia> g(a) = a < 10 ? missing : 1.0
g (generic function with 1 method)

julia> @inferred g(20)
ERROR: return type Float64 does not match inferred return type Union{Missing, Float64}
[...]

julia> @inferred Missing g(20)
1.0

julia> h(a) = a < 10 ? missing : f(a)
h (generic function with 1 method)

julia> @inferred Missing h(20)
ERROR: return type Int64 does not match inferred return type Union{Missing, Float64, Int64}
[...]
```
"""
macro inferred(ex)
_inferred(ex, __module__)
end
macro inferred(allow, ex)
_inferred(ex, __module__, allow)
end
function _inferred(ex, mod, allow = :(Union{}))
if Meta.isexpr(ex, :ref)
ex = Expr(:call, :getindex, ex.args...)
end
Expand All @@ -1307,13 +1334,15 @@ macro inferred(ex)
end
Base.remove_linenums!(quote
let
allow = $(esc(allow))
allow isa Type || throw(ArgumentError("@inferred requires a type as second argument"))
$(if any(a->(Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters)), ex.args)
# Has keywords
args = gensym()
kwargs = gensym()
quote
$(esc(args)), $(esc(kwargs)), result = $(esc(Expr(:call, _args_and_call, ex.args[2:end]..., ex.args[1])))
inftypes = $(gen_call_with_extracted_types(__module__, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...))))
inftypes = $(gen_call_with_extracted_types(mod, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...))))
end
else
# No keywords
Expand All @@ -1324,8 +1353,8 @@ macro inferred(ex)
end
end)
@assert length(inftypes) == 1
rettype = isa(result, Type) ? Type{result} : typeof(result)
rettype == inftypes[1] || error("return type $rettype does not match inferred return type $(inftypes[1])")
rettype = result isa Type ? Type{result} : typeof(result)
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
result
end
end)
Expand Down
32 changes: 15 additions & 17 deletions stdlib/Test/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,15 @@ for i in 1:6
end

# test @inferred
function uninferrable_function(i)
q = [1, "1"]
return q[i]
end

uninferrable_function(i) = (1, "1")[i]
uninferrable_small_union(i) = (1, nothing)[i]
@test_throws ErrorException @inferred(uninferrable_function(1))
@test @inferred(identity(1)) == 1
@test @inferred(Nothing, uninferrable_small_union(1)) === 1
@test @inferred(Nothing, uninferrable_small_union(2)) === nothing
@test_throws ErrorException @inferred(Missing, uninferrable_small_union(1))
@test_throws ErrorException @inferred(Missing, uninferrable_small_union(2))
@test_throws ArgumentError @inferred(nothing, uninferrable_small_union(1))

# Ensure @inferred only evaluates the arguments once
inferred_test_global = 0
Expand All @@ -512,8 +514,8 @@ end
struct SillyArray <: AbstractArray{Float64,1} end
Base.getindex(a::SillyArray, i) = rand() > 0.5 ? 0 : false
@testset "@inferred works with A[i] expressions" begin
@test @inferred((1:3)[2]) == 2
test_result = @test_throws ErrorException @inferred(SillyArray()[2])
@test (@inferred (1:3)[2]) == 2
test_result = @test_throws ErrorException (@inferred SillyArray()[2])
@test occursin("Bool", test_result.value.msg)
end
# Issue #14928
Expand All @@ -522,16 +524,12 @@ end

# Issue #17105
# @inferred with kwargs
function inferrable_kwtest(x; y=1)
2x
end
function uninferrable_kwtest(x; y=1)
2x+y
end
@test @inferred(inferrable_kwtest(1)) == 2
@test @inferred(inferrable_kwtest(1; y=1)) == 2
@test @inferred(uninferrable_kwtest(1)) == 3
@test @inferred(uninferrable_kwtest(1; y=2)) == 4
inferrable_kwtest(x; y=1) = 2x
uninferrable_kwtest(x; y=1) = 2x+y
@test (@inferred inferrable_kwtest(1)) == 2
@test (@inferred inferrable_kwtest(1; y=1)) == 2
@test (@inferred uninferrable_kwtest(1)) == 3
@test (@inferred uninferrable_kwtest(1; y=2)) == 4

@test_throws ErrorException @testset "$(error())" for i in 1:10
end
Expand Down