diff --git a/base/boot.jl b/base/boot.jl index b84f9b76838e0..218e4e2e533b1 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -957,7 +957,6 @@ function _hasmethod(@nospecialize(tt)) # this function has a special tfunc return Intrinsics.not_int(ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), tt, nothing, world) === nothing) end - # for backward compat arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...) const_arrayref(inbounds::Bool, A::Array, i::Int...) = Main.Base.getindex(A, i...) @@ -969,4 +968,6 @@ export arrayref, arrayset, arraysize, const_arrayref # For convenience EnterNode(old::EnterNode, new_dest::Int) = EnterNode(new_dest) +include(Core, "optimized_generics.jl") + ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 8b3ae55fe7429..f280b11532c19 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -6,6 +6,13 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I return singleton_type(ft) === func end +function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact}) + isinvoke = isexpr(x, :invoke) + (isinvoke || isexpr(x, :call)) || return false + ft = argextype(x.args[isinvoke ? 
2 : 1], ir) + return singleton_type(ft) === func +end + struct SSAUse kind::Symbol idx::Int @@ -819,6 +826,76 @@ function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr) return end +function lift_leaves_keyvalue(compact::IncrementalCompact, @nospecialize(key), + leaves::Vector{Any}, 𝕃ₒ::AbstractLattice) + # For every leaf, the lifted value + lifted_leaves = LiftedLeaves() + for i = 1:length(leaves) + leaf = leaves[i] + cache_key = leaf + if isa(leaf, AnySSAValue) + (def, leaf) = walk_to_def(compact, leaf) + if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact) + isexpr(def, :invoke) || return nothing # the arg counts below assume the :invoke layout; give up (don't abort) on a plain :call + if length(def.args) in (5, 6) + collection = def.args[end-2] + set_key = def.args[end-1] + set_val_idx = length(def.args) + elseif length(def.args) == 4 + collection = def.args[end-1] + # Key is deleted + # TODO: Model this + return nothing + elseif length(def.args) == 3 + collection = def.args[end] + # The whole collection is deleted + # TODO: Model this + return nothing + else + return nothing + end + if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true)) + lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves) + continue + end + # TODO: Continue walking the chain + return nothing + end + end + return nothing + end + return lifted_leaves +end + +function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice) + collection = stmt.args[end-1] + key = stmt.args[end] + + leaves, visited_philikes = collect_leaves(compact, collection, Any, 𝕃ₒ, phi_or_ifelse_predecessors) + isempty(leaves) && return + + lifted_leaves = lift_leaves_keyvalue(compact, key, leaves, 𝕃ₒ) + lifted_leaves === nothing && return + + result_t = Union{} + for v in values(lifted_leaves) + v === nothing && return + result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact)) + end + + lifted_val = perform_lifting!(compact, + visited_philikes, key, result_t, lifted_leaves, 
collection, nothing) + + compact[idx] = lifted_val === nothing ? nothing : Expr(:call, Core.tuple, lifted_val.val) + if lifted_val !== nothing + if !⊑(𝕃ₒ, compact[SSAValue(idx)][:type], result_t) + compact[SSAValue(idx)][:flag] |= IR_FLAG_REFINED + end + end + + return +end + # TODO: We could do the whole lifing machinery here, but really all # we want to do is clean this up when it got inserted by inlining, # which always targets simple `svec` call or `_compute_sparams`, @@ -1004,7 +1081,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - is_setfield = is_isdefined = is_finalizer = false + is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false field_ordering = :unspecified if is_known_call(stmt, setfield!, compact) 4 <= length(stmt.args) <= 5 || continue @@ -1094,6 +1171,9 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) lift_comparison!(isa, compact, idx, stmt, 𝕃ₒ) elseif is_known_call(stmt, Core.ifelse, compact) fold_ifelse!(compact, idx, stmt) + elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact) + 2 == (length(stmt.args) - (isexpr(stmt, :invoke) ? 
2 : 1)) || continue + lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ) elseif isexpr(stmt, :new) refine_new_effects!(𝕃ₒ, compact, idx, stmt) end diff --git a/base/dict.jl b/base/dict.jl index 83180f5c0ee1b..768f8215946b8 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -887,10 +887,35 @@ _similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} = include("hamt.jl") using .HashArrayMappedTries +using Core.OptimizedGenerics: KeyValue const HAMT = HashArrayMappedTries struct PersistentDict{K,V} <: AbstractDict{K,V} trie::HAMT.HAMT{K,V} + # Serves as a marker for an empty initialization + @noinline function KeyValue.set(::Type{PersistentDict{K, V}}) where {K, V} + new{K, V}(HAMT.HAMT{K,V}()) + end + @noinline function KeyValue.set(::Type{PersistentDict{K, V}}, ::Nothing, key, val) where {K, V} + new{K, V}(HAMT.HAMT{K, V}(key => val)) + end + @noinline function KeyValue.set(dict::PersistentDict{K, V}, key, val) where {K, V} + trie = dict.trie + h = HAMT.HashState(key) + found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true) + HAMT.insert!(found, present, trie, i, bi, hs, val) + return new{K, V}(top) + end + @noinline function KeyValue.set(dict::PersistentDict{K, V}, key) where {K, V} + trie = dict.trie + h = HAMT.HashState(key) + found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true) + if found && present + deleteat!(trie.data, i) + HAMT.unset!(trie, bi) + end + return new{K, V}(top) + end end """ @@ -925,19 +950,27 @@ Base.PersistentDict{Symbol, Int64} with 1 entry: """ PersistentDict -PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}()) -PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) -PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) +PersistentDict{K,V}() where {K, V} = KeyValue.set(PersistentDict{K,V}) +function PersistentDict{K,V}(KV::Pair) where {K,V} + KeyValue.set( + PersistentDict{K, V}, + nothing, + convert(K, KV.first), convert(V, KV.second)) # convert before `set` so the value handed to `set` is egal to what `get` returns 
+end +function PersistentDict(KV::Pair{K,V}) where {K,V} + KeyValue.set( + PersistentDict{K, V}, + nothing, + KV...) +end PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...) PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...) + + function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V} key = convert(K, key) val = convert(V, val) - trie = dict.trie - h = HAMT.HashState(key) - found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true) - HAMT.insert!(found, present, trie, i, bi, hs, val) - return PersistentDict(top) + return KeyValue.set(dict, key, val) end function PersistentDict{K,V}(KV::Pair, rest::Pair...) where {K,V} @@ -959,84 +992,60 @@ end eltype(::PersistentDict{K,V}) where {K,V} = Pair{K,V} function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - return false - end - key, val = key_val - - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return valcmp(val, leaf.val) && return true - end - return false + found = KeyValue.get(dict, key) + found === nothing && return false + return valcmp(val, only(found)) end function haskey(dict::PersistentDict{K}, key::K) where K - trie = dict.trie - h = HAMT.HashState(key) - found, present, _, _, _, _, _ = HAMT.path(trie, key, h) - return found && present + return KeyValue.get(dict, key) !== nothing end function getindex(dict::PersistentDict{K,V}, key::K) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - throw(KeyError(key)) - end - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val - end - throw(KeyError(key)) + found = KeyValue.get(dict, key) + found === nothing && throw(KeyError(key)) + 
return only(found) end function get(dict::PersistentDict{K,V}, key::K, default) where {K,V} - trie = dict.trie - if HAMT.islevel_empty(trie) - return default - end - h = HAMT.HashState(key) - found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) - if found && present - leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val - end - return default + found = KeyValue.get(dict, key) + found === nothing && return default + return only(found) end -function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V} +@noinline function KeyValue.get(dict::PersistentDict{K, V}, key) where {K, V} trie = dict.trie if HAMT.islevel_empty(trie) - return default + return nothing end h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} - return leaf.val + return (leaf.val,) end - return default() + return nothing end -iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state) +@noinline function KeyValue.get(default, dict::PersistentDict, key) + found = KeyValue.get(dict, key) + found === nothing && return default() + return only(found) +end + +function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V} + found = KeyValue.get(dict, key) + found === nothing && return default() + return only(found) +end function delete(dict::PersistentDict{K}, key::K) where K - trie = dict.trie - h = HAMT.HashState(key) - found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true) - if found && present - deleteat!(trie.data, i) - HAMT.unset!(trie, bi) - end - return PersistentDict(top) + return KeyValue.set(dict, key) end +iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state) + length(dict::PersistentDict) = HAMT.length(dict.trie) isempty(dict::PersistentDict) = HAMT.isempty(dict.trie) empty(::PersistentDict, ::Type{K}, ::Type{V}) where {K, V} = PersistentDict{K, V}() diff --git a/base/hamt.jl 
b/base/hamt.jl index e940f4e00b1d5..fc298b8b7a338 100644 --- a/base/hamt.jl +++ b/base/hamt.jl @@ -65,12 +65,18 @@ mutable struct HAMT{K, V} HAMT{K,V}(data, bitmap) where {K,V} = new{K,V}(data, bitmap) HAMT{K, V}() where {K, V} = new{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP)) end -function HAMT{K,V}((k,v)::Pair) where {K, V} - k = convert(K, k) - v = convert(V, v) + +@Base.assume_effects :nothrow function init_hamt(K, V, k, v) # For a single element we can't have a hash-collision trie = HAMT{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP)) trie.data[1] = Leaf{K,V}(k,v) + return trie +end + +function HAMT{K,V}((k,v)::Pair) where {K, V} + k = convert(K, k) + v = convert(V, v) + trie = init_hamt(K, V, k, v) bi = BitmapIndex(HashState(k)) set!(trie, bi) return trie diff --git a/base/optimized_generics.jl b/base/optimized_generics.jl new file mode 100644 index 0000000000000..86b54a294564d --- /dev/null +++ b/base/optimized_generics.jl @@ -0,0 +1,57 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +module OptimizedGenerics + +# This file defines interfaces that are recognized and optimized by the compiler +# They are intended to be used by data structure implementations that wish to +# opt into some level of compiler optimizations. These interfaces are +# EXPERIMENTAL and currently intended for use by Base only. They are subject +# to change or removal without notice. It is undefined behavior to add methods +# to these generics that do not conform to the specified interface. +# +# The intended way to use these generics is that data structures will provide +# appropriate implementations for a generic. In the absence of compiler +# optimizations, these behave like regular methods. However, the compiler is +# semantically allowed to perform certain structural optimizations on +# appropriate combinations of these intrinsics without proving correctness. 
+ +# Compiler-recognized generics for immutable key-value stores (dicts, etc.) +""" + module KeyValue + +Implements a key-value like interface where the compiler has liberty to perform +the following transformations. The core optimization semantically allowed for +the compiler is: + + get(set(x, key, val), key) -> (val,) + +where the compiler will recursively look through `x`. Keys are compared by +egality. + +Implementations must observe the following constraints: + +1. It is undefined behavior for `get` not to return the exact (by egality) val + stored for a given `key`. +""" +module KeyValue + """ + set(collection, [key [, val]]) + set(T, collection, key, val) + + Set the `key` in `collection` to `val`. If `val` is omitted, deletes the + value from the collection. If `key` is omitted as well, deletes all elements + of the collection. + """ + function set end + + """ + get(collection, key) + + Retrieve the value corresponding to `key` in `collection` as a single + element tuple or `nothing` if no value corresponding to the key was found. + `key`s are compared by egal. + """ + function get end +end + +end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 204d0400ea701..fe5263059b0ba 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -1616,3 +1616,13 @@ let m = Meta.@lower 1 + 1 end # JET.test_opt(Core.Compiler.cfg_simplify!, (Core.Compiler.IRCode,)) + +# Test support for Core.OptimizedGenerics.KeyValue protocol +function persistent_dict_elim() + a = Base.PersistentDict(:a => 1) + return a[:a] +end +# Ideally we would be able to fully eliminate this, +# but currently this would require an extra round of constprop +@test_broken fully_eliminated(persistent_dict_elim) +@test code_typed(persistent_dict_elim)[1][1].code[end] == Core.ReturnNode(1)