Skip to content

Commit

Permalink
RFC: Compiler support for optimizing PersistentDict
Browse files Browse the repository at this point in the history
This is part of the work to address #51352 by attempting to allow
the compiler to perform SRAO on persistent data structures like
`PersistentDict` as if they were regular immutable data structures.
These sorts of data structures have very complicated internals
(with lots of mutation, memory sharing, etc.), but a relatively
simple interface. As such, it is unlikely that our compiler will
have sufficient power to optimize this interface by analyzing
the implementation.

We thus need to come up with some other mechanism that gives the
compiler license to perform the requisite optimization. One way
would be to just hardcode `PersistentDict` into the compiler,
optimizing it like any of the other builtin datatypes. However,
this is of course very unsatisfying. At the other end of the
spectrum would be something like a generic rewrite rule system
(e-graphs anyone?) that would let the PersistentDict
implementation declare its interface to the compiler and the
compiler would use this for optimization (in a perfect world,
the actual rewrite would then be checked using some sort of
formal methods). I think that would be interesting, but we're
very far from even being able to design something like that
(at least in Base - experiments with external AbstractInterpreters
in this direction are encouraged).

This PR tries to come up with a reasonable middle ground, where
the compiler gets some knowledge of the protocol hardcoded without
having to know about the implementation details of the data structure.

The basic ideas is that `Core` provides some magic generic functions
that implementations can extend. Semantically, they are not special.
They dispatch as usual, and implementations are expected to work
properly even in the absence of any compiler optimizations.

However, the compiler is semantically permitted to perform structural
optimization using these magic generic functions. In the concrete
case, this PR introduces the `KeyValue` interface which consists
of two generic functions, `get` and `set`. The core optimization
is that the compiler is allowed to rewrite any occurrence of
`get(set(x, k, v), k)` into `v` without additional legality checks.
In particular, the compiler performs no type checks, conversions, etc.
The higher level implementation code is expected to do all that.

This approach closely matches the general direction we've been taking
in external AbstractInterpreters for embedding additional semantics
and optimization opportunities into Julia code (although we generally
use methods there, rather than full generic functions), so I think
we have some evidence that this sort of approach works reasonably well.

Nevertheless, this is certainly an experiment and the interface is
explicitly declared unstable.

## Current Status

This is fully working and implemented, but the optimization currently
bails on anything but the simplest cases. Filling all those cases in
is not particularly hard, but should be done along with a more invasive
refactoring of SROA, so we should figure out the general direction
here first and then we can finish all that up in a follow-up cleanup.

## Obligatory benchmark
Before:
```
julia> using BenchmarkTools

julia> function foo()
           a = Base.PersistentDict(:a => 1)
           return a[:a]
       end
foo (generic function with 1 method)

julia> @benchmark foo()
BenchmarkTools.Trial: 10000 samples with 993 evaluations.
 Range (min … max):  32.940 ns …  28.754 μs  ┊ GC (min … max):  0.00% … 99.76%
 Time  (median):     49.647 ns               ┊ GC (median):     0.00%
 Time  (mean ± σ):   57.519 ns ± 333.275 ns  ┊ GC (mean ± σ):  10.81% ±  2.22%

        ▃█▅               ▁▃▅▅▃▁                ▁▃▂   ▂
  ▁▂▄▃▅▇███▇▃▁▂▁▁▁▁▁▁▁▁▂▂▅██████▅▂▁▁▁▁▁▁▁▁▁▁▂▃▃▇███▇▆███▆▄▃▃▂▂ ▃
  32.9 ns         Histogram: frequency by time         68.6 ns <

 Memory estimate: 128 bytes, allocs estimate: 4.

julia> @code_typed foo()
CodeInfo(
1 ─ %1  = invoke Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}(Base.HashArrayMappedTries.undef::UndefInitializer, 1::Int64)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %2  = %new(Base.HashArrayMappedTries.HAMT{Symbol, Int64}, %1, 0x00000000)::Base.HashArrayMappedTries.HAMT{Symbol, Int64}
│   %3  = %new(Base.HashArrayMappedTries.Leaf{Symbol, Int64}, :a, 1)::Base.HashArrayMappedTries.Leaf{Symbol, Int64}
│   %4  = Base.getfield(%2, :data)::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %5  = $(Expr(:boundscheck, true))::Bool
└──       goto #5 if not %5
2 ─ %7  = Base.sub_int(1, 1)::Int64
│   %8  = Base.bitcast(UInt64, %7)::UInt64
│   %9  = Base.getfield(%4, :size)::Tuple{Int64}
│   %10 = $(Expr(:boundscheck, true))::Bool
│   %11 = Base.getfield(%9, 1, %10)::Int64
│   %12 = Base.bitcast(UInt64, %11)::UInt64
│   %13 = Base.ult_int(%8, %12)::Bool
└──       goto #4 if not %13
3 ─       goto #5
4 ─ %16 = Core.tuple(1)::Tuple{Int64}
│         invoke Base.throw_boundserror(%4::Vector{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}, %16::Tuple{Int64})::Union{}
└──       unreachable
5 ┄ %19 = Base.getfield(%4, :ref)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│   %20 = Base.memoryref(%19, 1, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
│         Base.memoryrefset!(%20, %3, :not_atomic, false)::MemoryRef{Union{Base.HashArrayMappedTries.HAMT{Symbol, Int64}, Base.HashArrayMappedTries.Leaf{Symbol, Int64}}}
└──       goto #6
6 ─ %23 = Base.getfield(%2, :bitmap)::UInt32
│   %24 = Base.or_int(%23, 0x00010000)::UInt32
│         Base.setfield!(%2, :bitmap, %24)::UInt32
└──       goto #7
7 ─ %27 = %new(Base.PersistentDict{Symbol, Int64}, %2)::Base.PersistentDict{Symbol, Int64}
└──       goto #8
8 ─ %29 = invoke Base.getindex(%27::Base.PersistentDict{Symbol, Int64}, 🅰️:Symbol)::Int64
└──       return %29
```

After:
```
julia> using BenchmarkTools

julia> function foo()
           a = Base.PersistentDict(:a => 1)
           return a[:a]
       end
foo (generic function with 1 method)

julia> @benchmark foo()
BenchmarkTools.Trial: 10000 samples with 1000 evaluations.
 Range (min … max):  2.459 ns … 11.320 ns  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     2.460 ns              ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.469 ns ±  0.183 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▂    █                                              ▁    █ ▂
  █▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁█ █
  2.46 ns      Histogram: log(frequency) by time     2.47 ns <

 Memory estimate: 0 bytes, allocs estimate: 0.

julia> @code_typed foo()
CodeInfo(
1 ─     return 1
```
  • Loading branch information
Keno committed Nov 2, 2023
1 parent 3a6c418 commit f5e821f
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 61 deletions.
2 changes: 2 additions & 0 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -959,4 +959,6 @@ arraysize(a::Array) = a.size
arraysize(a::Array, i::Int) = sle_int(i, nfields(a.size)) ? getfield(a.size, i) : 1
export arrayref, arrayset, arraysize, const_arrayref

include(Core, "optimized_generics.jl")

ccall(:jl_set_istopmod, Cvoid, (Any, Bool), Core, true)
87 changes: 86 additions & 1 deletion base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@ function is_known_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,I
return singleton_type(ft) === func
end

function is_known_invoke_or_call(@nospecialize(x), @nospecialize(func), ir::Union{IRCode,IncrementalCompact})
isinvoke = isexpr(x, :invoke)
(isinvoke || isexpr(x, :call)) || return false
ft = argextype(x.args[isinvoke ? 2 : 1], ir)
return singleton_type(ft) === func
end

struct SSAUse
kind::Symbol
idx::Int
Expand Down Expand Up @@ -817,6 +824,81 @@ function lift_svec_ref!(compact::IncrementalCompact, idx::Int, stmt::Expr)
return
end

function lift_leaves_keyvalue(compact::IncrementalCompact, @nospecialize(key),
leaves::Vector{Any}, 𝕃ₒ::AbstractLattice)
# For every leaf, the lifted value
lifted_leaves = LiftedLeaves()
maybe_undef = false
for i = 1:length(leaves)
leaf = leaves[i]
cache_key = leaf
if isa(leaf, AnySSAValue)
(def, leaf) = walk_to_def(compact, leaf)
if is_known_invoke_or_call(def, Core.OptimizedGenerics.KeyValue.set, compact)
@assert isexpr(def, :invoke)
if length(def.args) in (5, 6)
collection = def.args[end-2]
set_key = def.args[end-1]
set_val_idx = length(def.args)
elseif length(def.args) == 4
collection = def.args[end-1]
# Key is deleted
# TODO: Model this
return nothing
elseif length(def.args) == 3
collection = def.args[end]
# The whole collection is deleted
# TODO: Model this
return nothing
else
return nothing
end
if set_key === key || (egal_tfunc(𝕃ₒ, argextype(key, compact), argextype(set_key, compact)) == Const(true))
lift_arg!(compact, leaf, cache_key, def, set_val_idx, lifted_leaves)
continue
end
# TODO: Continue walking the chain
return nothing
end
end
return nothing
end
return lifted_leaves, maybe_undef
end

function lift_keyvalue_get!(compact::IncrementalCompact, idx::Int, stmt::Expr, 𝕃ₒ::AbstractLattice)
# TODO: Support variants with callbacks
#length(stmt.args) == 4 || return

collection = stmt.args[end-1]
key = stmt.args[end]

leaves, visited_philikes = collect_leaves(compact, collection, Any, 𝕃ₒ, phi_or_ifelse_predecessors)
isempty(leaves) && return

lifted_result = lift_leaves_keyvalue(compact, key, leaves, 𝕃ₒ)
lifted_result === nothing && return
lifted_leaves, any_undef = lifted_result

result_t = Union{}
for v in values(lifted_leaves)
v === nothing && return
result_t = tmerge(𝕃ₒ, result_t, argextype(v.val, compact))
end

lifted_val = perform_lifting!(compact,
visited_philikes, key, result_t, lifted_leaves, collection, nothing)

compact[idx] = lifted_val === nothing ? nothing : lifted_val.val
if lifted_val !== nothing
if !(𝕃ₒ, compact[SSAValue(idx)][:type], result_t)
compact[SSAValue(idx)][:flag] |= IR_FLAG_REFINED
end
end

return
end

# TODO: We could do the whole lifing machinery here, but really all
# we want to do is clean this up when it got inserted by inlining,
# which always targets simple `svec` call or `_compute_sparams`,
Expand Down Expand Up @@ -1002,7 +1084,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
is_setfield = is_isdefined = is_finalizer = false
is_setfield = is_isdefined = is_finalizer = is_keyvalue_get = false
field_ordering = :unspecified
if is_known_call(stmt, setfield!, compact)
4 <= length(stmt.args) <= 5 || continue
Expand Down Expand Up @@ -1092,6 +1174,9 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
lift_comparison!(isa, compact, idx, stmt, 𝕃ₒ)
elseif is_known_call(stmt, Core.ifelse, compact)
fold_ifelse!(compact, idx, stmt)
elseif is_known_invoke_or_call(stmt, Core.OptimizedGenerics.KeyValue.get, compact)
4 <= length(stmt.args) <= 6 || continue
lift_keyvalue_get!(compact, idx, stmt, 𝕃ₒ)
elseif isexpr(stmt, :new)
refine_new_effects!(𝕃ₒ, compact, idx, stmt)
end
Expand Down
115 changes: 59 additions & 56 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -889,10 +889,35 @@ _similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =

include("hamt.jl")
using .HashArrayMappedTries
using Core.OptimizedGenerics: KeyValue
const HAMT = HashArrayMappedTries

struct PersistentDict{K,V} <: AbstractDict{K,V}
trie::HAMT.HAMT{K,V}
# Serves as a marker for an empty initialization
@noinline function KeyValue.set(::Type{PersistentDict{K, V}}) where {K, V}
new{K, V}(HAMT.HAMT{K,V}())
end
@noinline function KeyValue.set(::Type{PersistentDict{K, V}}, ::Nothing, key, val) where {K, V}
new{K, V}(HAMT.HAMT{K, V}(key, val))
end
@noinline function KeyValue.set(dict::PersistentDict{K, V}, key, val) where {K, V}
trie = dict.trie
h = hash(key)
found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true)
HAMT.insert!(found, present, trie, i, bi, hs, val)
return new{K, V}(top)
end
@noinline function KeyValue.set(dict::PersistentDict{K, V}, key) where {K, V}
trie = dict.trie
h = hash(key)
found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true)
if found && present
deleteat!(trie.data, i)
HAMT.unset!(trie, bi)
end
return new{K, V}(top)
end
end

"""
Expand Down Expand Up @@ -923,19 +948,27 @@ Base.PersistentDict{Symbol, Int64} with 1 entry:
"""
PersistentDict

PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}())
PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV...))
PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV...))
PersistentDict{K,V}() where {K, V} = KeyValue.set(PersistentDict{K,V})
function PersistentDict{K,V}(KV::Pair) where {K,V}
KeyValue.set(
PersistentDict{K, V},
nothing,
KV...)
end
function PersistentDict(KV::Pair{K,V}) where {K,V}
KeyValue.set(
PersistentDict{K, V},
nothing,
KV...)
end
PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...)
PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...)


function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V}
key = convert(K, key)
val = convert(V, val)
trie = dict.trie
h = hash(key)
found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true)
HAMT.insert!(found, present, trie, i, bi, hs, val)
return PersistentDict(top)
return KeyValue.set(dict, key, val)
end

function PersistentDict(kv::Pair, rest::Pair...)
Expand All @@ -950,84 +983,54 @@ end
eltype(::PersistentDict{K,V}) where {K,V} = Pair{K,V}

function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return false
end

key, val = key_val

h = hash(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return valcmp(val, leaf.val) && return true
end
return false
return KeyValue.get(found->valcmp(val, found), ()->false, dict, key)
end

function haskey(dict::PersistentDict{K}, key::K) where K
trie = dict.trie
h = hash(key)
found, present, _, _, _, _, _ = HAMT.path(trie, key, h)
return found && present
return KeyValue.get(_->true, ()->false, dict, key)
end

function getindex(dict::PersistentDict{K,V}, key::K) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
throw(KeyError(key))
return KeyValue.get(dict, key) do
return throw(KeyError(key))
end
h = hash(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
end
throw(KeyError(key))
end

function get(dict::PersistentDict{K,V}, key::K, default) where {K,V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return KeyValue.get(dict, key) do
return default
end
h = hash(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
end
return default
end

function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V}
@noinline function KeyValue.get(transform, default, dict::PersistentDict{K, V}, key) where {K, V}
trie = dict.trie
if HAMT.islevel_empty(trie)
return default
return default()
end
h = hash(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
return leaf.val
return transform(leaf.val)
end
return default()
end

iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state)
@noinline function KeyValue.get(default, dict::PersistentDict, key)
KeyValue.get(identity, default, dict, key)
end

function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V}
return KeyValue.get(default, dict, key)
end

function delete(dict::PersistentDict{K}, key::K) where K
trie = dict.trie
h = hash(key)
found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true)
if found && present
deleteat!(trie.data, i)
HAMT.unset!(trie, bi)
end
return PersistentDict(top)
return KeyValue.set(dict, key)
end

iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state)

length(dict::PersistentDict) = HAMT.length(dict.trie)
isempty(dict::PersistentDict) = HAMT.isempty(dict.trie)
empty(::PersistentDict, ::Type{K}, ::Type{V}) where {K, V} = PersistentDict{K, V}()
12 changes: 9 additions & 3 deletions base/hamt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@ mutable struct HAMT{K, V}
bitmap::BITMAP
end
HAMT{K, V}() where {K, V} = HAMT(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP))
function HAMT{K,V}(k::K, v) where {K, V}
v = convert(V, v)

@Base.assume_effects :nothrow function init_hamt(K, V, k, v)
# For a single element we can't have a hash-collision
trie = HAMT(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP))
trie.data[1] = Leaf{K,V}(k,v)
@inbounds trie.data[1] = Leaf{K,V}(k,v)
return trie
end

function HAMT{K,V}(k::K, v) where {K, V}
v = convert(V, v)
bi = BitmapIndex(HashState(k))
trie = init_hamt(K, V, k, v)
set!(trie, bi)
return trie
end
Expand Down
58 changes: 58 additions & 0 deletions base/optimized_generics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module OptimizedGenerics

# This file defines interfaces that are recognized and optimized by the compiler
# They are intended to be used by data structure implementations that wish to
# opt into some level of compiler optimizations. These interfaces are
# EXPERIMENTAL and currently intended for use by Base only. They are subject
# to change or removal without notice. It is undefined behavior to add methods
# to these generics that do not conform to the specified interface.
#
# The intended way to use these generics is that data structures will provide
# appropriate implementations for a generic. In the absence of compiler
# optimizations, these behave like regular methods. However, the compiler is
# semantically allowed to perform certain structural optimizations on
# appropriate combinations of these intrinsics without proving correctness.

# Compiler-recognized generics for immutable key-value stores (dicts, etc.)
"""
module KeyValue
Implements a key-value like interface where the compiler has liberty to perform
the following transformations. The core optimization semantically allowed for
the compiler is:
get(set(x, key, val), key) -> val
where the compiler will recursively look through `x`. Keys are compared by
egality.
Implementations must observe the following constraints:
1. It is undefined behavior for `get` not to return the exact (by egality) val
stored for a given `key`.
"""
module KeyValue
"""
set(collection, [key [, val]])
set(T, collection, key, val)
Set the `key` in `collection` to `val`. If `val` is omitted, deletes the
value from the collection. If `key` is omitted as well, deletes all elements
of the collection.
"""
function set end

"""
get([[transform,] default,] collection, key, val)
Retrieve the value corresponding to `key` in `collection`. Optionally takes
a `default` callback that is executed if `key` is not found and a `transform`
callback that is executed only if the value is found (but not on the return)
value of `default`.
"""
function get end
end

end
2 changes: 1 addition & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ function objectid(x)
return _objectid(x)
end
function _foldable_objectid(@nospecialize(x))
@_foldable_meta
@_total_meta
_objectid(x)
end
_objectid(@nospecialize(x)) = ccall(:jl_object_id, UInt, (Any,), x)
Expand Down
7 changes: 7 additions & 0 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1568,3 +1568,10 @@ let m = Meta.@lower 1 + 1

Core.Compiler.verify_ir(ir)
end

# Test support for Core.OptimizedGenerics.KeyValue protocol
function persistent_dict_elim()
a = Base.PersistentDict(:a => 1)
return a[:a]
end
@test fully_eliminated(persistent_dict_elim)

0 comments on commit f5e821f

Please sign in to comment.