Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AbstractInterpreter: enable selective pure/concrete eval for external AbstractInterpreter with overlayed method table #44515

Merged
merged 2 commits into from
Mar 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 78 additions & 42 deletions base/compiler/abstractinterpretation.jl

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ mutable struct InferenceState
#=parent=#nothing,
#=cached=#cache === :global,
#=inferred=#false, #=dont_work_on_me=#false,
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, inbounds_taints_consistency),
#=ipo_effects=#Effects(consistent, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false, inbounds_taints_consistency),
interp)
result.result = frame
cache !== :no && push!(get_inference_cache(interp), result)
Expand Down
60 changes: 35 additions & 25 deletions base/compiler/methodtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@ end
getindex(result::MethodLookupResult, idx::Int) = getindex(result.matches, idx)::MethodMatch

"""
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) -> MethodLookupResult or missing
findall(sig::Type, view::MethodTableView; limit::Int=typemax(Int)) ->
(matches::MethodLookupResult, overlayed::Bool) or missing

Find all methods in the given method table `view` that are applicable to the
given signature `sig`. If no applicable methods are found, an empty result is
returned. If the number of applicable methods exceeded the specified limit,
`missing` is returned.
Find all methods in the given method table `view` that are applicable to the given signature `sig`.
If no applicable methods are found, an empty result is returned.
If the number of applicable methods exceeded the specified limit, `missing` is returned.
`overlayed` indicates if any matching method is defined in an overlayed method table.
"""
function findall(@nospecialize(sig::Type), table::InternalMethodTable; limit::Int=Int(typemax(Int32)))
return _findall(sig, nothing, table.world, limit)
result = _findall(sig, nothing, table.world, limit)
result === missing && return missing
return result, false
end

function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int=Int(typemax(Int32)))
Expand All @@ -57,7 +60,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
nr = length(result)
if nr ≥ 1 && result[nr].fully_covers
# no need to fall back to the internal method table
return result
return result, true
end
# fall back to the internal method table
fallback_result = _findall(sig, nothing, table.world, limit)
Expand All @@ -68,7 +71,7 @@ function findall(@nospecialize(sig::Type), table::OverlayMethodTable; limit::Int
WorldRange(
max(result.valid_worlds.min_world, fallback_result.valid_worlds.min_world),
min(result.valid_worlds.max_world, fallback_result.valid_worlds.max_world)),
result.ambig | fallback_result.ambig)
result.ambig | fallback_result.ambig), !isempty(result)
end

function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt, limit::Int)
Expand All @@ -83,31 +86,38 @@ function _findall(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable},
end

"""
findsup(sig::Type, view::MethodTableView) -> Tuple{MethodMatch, WorldRange} or nothing

Find the (unique) method `m` such that `sig <: m.sig`, while being more
specific than any other method with the same property. In other words, find
the method which is the least upper bound (supremum) under the specificity/subtype
relation of the queried `signature`. If `sig` is concrete, this is equivalent to
asking for the method that will be called given arguments whose types match the
given signature. This query is also used to implement `invoke`.

Such a method `m` need not exist. It is possible that no method is an
upper bound of `sig`, or it is possible that among the upper bounds, there
is no least element. In both cases `nothing` is returned.
findsup(sig::Type, view::MethodTableView) ->
(match::MethodMatch, valid_worlds::WorldRange, overlayed::Bool) or nothing

Find the (unique) method such that `sig <: match.method.sig`, while being more
specific than any other method with the same property. In other words, find the method
which is the least upper bound (supremum) under the specificity/subtype relation of
the queried `sig`nature. If `sig` is concrete, this is equivalent to asking for the method
that will be called given arguments whose types match the given signature.
Note that this query is also used to implement `invoke`.

Such a matching method `match` doesn't necessarily exist.
It is possible that no method is an upper bound of `sig`, or
it is possible that among the upper bounds, there is no least element.
In both cases `nothing` is returned.

`overlayed` indicates if the matching method is defined in an overlayed method table.
"""
function findsup(@nospecialize(sig::Type), table::InternalMethodTable)
return _findsup(sig, nothing, table.world)
return (_findsup(sig, nothing, table.world)..., false)
end

function findsup(@nospecialize(sig::Type), table::OverlayMethodTable)
match, valid_worlds = _findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds
match !== nothing && return match, valid_worlds, true
# fall back to the internal method table
fallback_match, fallback_valid_worlds = _findsup(sig, nothing, table.world)
return fallback_match, WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world))
return (
fallback_match,
WorldRange(
max(valid_worlds.min_world, fallback_valid_worlds.min_world),
min(valid_worlds.max_world, fallback_valid_worlds.max_world)),
false)
end

function _findsup(@nospecialize(sig::Type), mt::Union{Nothing,Core.MethodTable}, world::UInt)
Expand Down
1 change: 1 addition & 0 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ function Base.show(io::IO, e::Core.Compiler.Effects)
print(io, ',')
printstyled(io, string(tristate_letter(e.terminates), 't'); color=tristate_color(e.terminates))
print(io, ')')
e.overlayed && printstyled(io, ''; color=:red)
end

@specialize
12 changes: 8 additions & 4 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1789,11 +1789,11 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
if (f === Core.getfield || f === Core.isdefined) && length(argtypes) >= 3
# consistent if the argtype is immutable
if isvarargtype(argtypes[2])
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
end
s = widenconst(argtypes[2])
if isType(s) || !isa(s, DataType) || isabstracttype(s)
return Effects(Effects(), effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE)
return Effects(; effect_free=ALWAYS_TRUE, terminates=ALWAYS_TRUE, overlayed=false)
end
s = s::DataType
ipo_consistent = !ismutabletype(s)
Expand Down Expand Up @@ -1826,7 +1826,9 @@ function builtin_effects(f::Builtin, argtypes::Vector{Any}, rt)
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
ALWAYS_TRUE)
#=terminates=#ALWAYS_TRUE,
#=overlayed=#false,
)
end

function builtin_nothrow(@nospecialize(f), argtypes::Array{Any, 1}, @nospecialize(rt))
Expand Down Expand Up @@ -2007,7 +2009,9 @@ function intrinsic_effects(f::IntrinsicFunction, argtypes::Vector{Any})
ipo_consistent ? ALWAYS_TRUE : ALWAYS_FALSE,
effect_free ? ALWAYS_TRUE : ALWAYS_FALSE,
nothrow ? ALWAYS_TRUE : TRISTATE_UNKNOWN,
ALWAYS_TRUE)
#=terminates=#ALWAYS_TRUE,
#=overlayed=#false,
)
end

# TODO: this function is a very buggy and poor model of the return_type function
Expand Down
6 changes: 3 additions & 3 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ function rt_adjust_effects(@nospecialize(rt), ipo_effects::Effects)
# but we don't currently model idempontency using dataflow, so we don't notice.
# Fix that up here to improve precision.
if !ipo_effects.inbounds_taints_consistency && rt === Union{}
return Effects(ipo_effects, consistent=ALWAYS_TRUE)
return Effects(ipo_effects; consistent=ALWAYS_TRUE)
end
return ipo_effects
end
Expand Down Expand Up @@ -755,11 +755,11 @@ function merge_call_chain!(parent::InferenceState, ancestor::InferenceState, chi
# and ensure that walking the parent list will get the same result (DAG) from everywhere
# Also taint the termination effect, because we can no longer guarantee the absence
# of recursion.
tristate_merge!(parent, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
tristate_merge!(parent, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
while true
add_cycle_backedge!(child, parent, parent.currpc)
union_caller_cycle!(ancestor, child)
tristate_merge!(child, Effects(EFFECTS_TOTAL, terminates=TRISTATE_UNKNOWN))
tristate_merge!(child, Effects(EFFECTS_TOTAL; terminates=TRISTATE_UNKNOWN))
child = parent
child === ancestor && break
parent = child.parent::InferenceState
Expand Down
36 changes: 22 additions & 14 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct Effects
effect_free::TriState
nothrow::TriState
terminates::TriState
overlayed::Bool
# This effect is currently only tracked in inference and modified
# :consistent before caching. We may want to track it in the future.
inbounds_taints_consistency::Bool
Expand All @@ -46,27 +47,33 @@ function Effects(
consistent::TriState,
effect_free::TriState,
nothrow::TriState,
terminates::TriState)
terminates::TriState,
overlayed::Bool)
return Effects(
consistent,
effect_free,
nothrow,
terminates,
overlayed,
false)
end
Effects() = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN)

function Effects(e::Effects;
const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, false)
const EFFECTS_UNKNOWN = Effects(TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, TRISTATE_UNKNOWN, true)

function Effects(e::Effects = EFFECTS_UNKNOWN;
consistent::TriState = e.consistent,
effect_free::TriState = e.effect_free,
nothrow::TriState = e.nothrow,
terminates::TriState = e.terminates,
overlayed::Bool = e.overlayed,
inbounds_taints_consistency::Bool = e.inbounds_taints_consistency)
return Effects(
consistent,
effect_free,
nothrow,
terminates,
overlayed,
inbounds_taints_consistency)
end

Expand All @@ -82,20 +89,20 @@ is_removable_if_unused(effects::Effects) =
effects.terminates === ALWAYS_TRUE &&
effects.nothrow === ALWAYS_TRUE

const EFFECTS_TOTAL = Effects(ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE, ALWAYS_TRUE)

function encode_effects(e::Effects)
return e.consistent.state |
(e.effect_free.state << 2) |
(e.nothrow.state << 4) |
(e.terminates.state << 6)
return (e.consistent.state << 1) |
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is bigger than a UInt8 now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #43899 I changed this to a UInt32, so you may want to pick up those pieces.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since e.overlayed is a Boolean property, I think this can be encoded within UInt8?

julia> let maxbits = 0x02
           (maxbits << 1) |
           (maxbits << 3) |
           (maxbits << 5) |
           (maxbits << 7) |
           true
       end |> Int
85

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, 4*2+1 is 9 bits. You're cutting of the top bit (in particular the top bit of the two bit unit you're shifting by 7).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Au, you're right. Maybe we can pick UInt16 instead though?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think UInt32 is fine. I have at least two more effects I think would be interesting, so with you're extra bit, and the one I already have in that PR, we'd be out again. It's not a huge amount of memory.

(e.effect_free.state << 3) |
(e.nothrow.state << 5) |
(e.terminates.state << 7) |
(e.overlayed)
end
function decode_effects(e::UInt8)
return Effects(
TriState(e & 0x3),
TriState((e >> 2) & 0x3),
TriState((e >> 4) & 0x3),
TriState((e >> 6) & 0x3),
TriState((e >> 1) & 0x03),
TriState((e >> 3) & 0x03),
TriState((e >> 5) & 0x03),
TriState((e >> 7) & 0x03),
e & 0x01 ≠ 0x00,
false)
end

Expand All @@ -109,6 +116,7 @@ function tristate_merge(old::Effects, new::Effects)
old.nothrow, new.nothrow),
tristate_merge(
old.terminates, new.terminates),
old.overlayed | new.overlayed,
old.inbounds_taints_consistency | new.inbounds_taints_consistency)
end

Expand Down Expand Up @@ -158,7 +166,7 @@ mutable struct InferenceResult
arginfo#=::Union{Nothing,Tuple{ArgInfo,InferenceState}}=# = nothing)
argtypes, overridden_by_const = matching_cache_argtypes(linfo, arginfo)
return new(linfo, argtypes, overridden_by_const, Any, nothing,
WorldRange(), Effects(), Effects(), nothing)
WorldRange(), Effects(; overlayed=false), Effects(; overlayed=false), nothing)
end
end

Expand Down
5 changes: 3 additions & 2 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1291,11 +1291,12 @@ function code_typed_opaque_closure(@nospecialize(closure::Core.OpaqueClosure);
end
end

function return_types(@nospecialize(f), @nospecialize(types=default_tt(f)), interp=Core.Compiler.NativeInterpreter())
function return_types(@nospecialize(f), @nospecialize(types=default_tt(f));
world = get_world_counter(),
interp = Core.Compiler.NativeInterpreter(world))
ccall(:jl_is_in_pure_context, Bool, ()) && error("code reflection cannot be used from generated functions")
types = to_tuple_type(types)
rt = []
world = get_world_counter()
for match in _methods(f, types, -1, world)::Vector
match = match::Core.MethodMatch
meth = func_for_method_checked(match.method, types, match.sparams)
Expand Down
54 changes: 41 additions & 13 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,25 +41,53 @@ import Base.Experimental: @MethodTable, @overlay
@MethodTable(OverlayedMT)
CC.method_table(interp::MTOverlayInterp) = CC.OverlayMethodTable(CC.get_world_counter(interp), OverlayedMT)

@overlay OverlayedMT sin(x::Float64) = 1
@test Base.return_types((Int,), MTOverlayInterp()) do x
sin(x)
end == Any[Int]
@test Base.return_types((Any,), MTOverlayInterp()) do x
Base.@invoke sin(x::Float64)
end == Any[Int]
strangesin(x) = sin(x)
@overlay OverlayedMT strangesin(x::Float64) = iszero(x) ? nothing : cos(x)
@test Base.return_types((Float64,); interp=MTOverlayInterp()) do x
strangesin(x)
end |> only === Union{Float64,Nothing}
@test Base.return_types((Any,); interp=MTOverlayInterp()) do x
Base.@invoke strangesin(x::Float64)
end |> only === Union{Float64,Nothing}

# fallback to the internal method table
@test Base.return_types((Int,), MTOverlayInterp()) do x
@test Base.return_types((Int,); interp=MTOverlayInterp()) do x
cos(x)
end == Any[Float64]
@test Base.return_types((Any,), MTOverlayInterp()) do x
end |> only === Float64
@test Base.return_types((Any,); interp=MTOverlayInterp()) do x
Base.@invoke cos(x::Float64)
end == Any[Float64]
end |> only === Float64

# not fully covered overlay method match
overlay_match(::Any) = nothing
@overlay OverlayedMT overlay_match(::Int) = missing
@test Base.return_types((Any,), MTOverlayInterp()) do x
@test Base.return_types((Any,); interp=MTOverlayInterp()) do x
overlay_match(x)
end == Any[Union{Nothing,Missing}]
end |> only === Union{Nothing,Missing}

# partial pure/concrete evaluation
@test Base.return_types(; interp=MTOverlayInterp()) do
isbitstype(Int) ? nothing : missing
end |> only === Nothing
Base.@assume_effects :terminates_globally function issue41694(x)
res = 1
1 < x < 20 || throw("bad")
while x > 1
res *= x
x -= 1
end
return res
end
@test Base.return_types(; interp=MTOverlayInterp()) do
issue41694(3) == 6 ? nothing : missing
end |> only === Nothing

# disable partial pure/concrete evaluation when tainted by any overlayed call
Base.@assume_effects :total totalcall(f, args...) = f(args...)
@test Base.return_types(; interp=MTOverlayInterp()) do
if totalcall(strangesin, 1.0) == cos(1.0)
return nothing
else
return missing
end
end |> only === Nothing