add native caching
collinwarner committed May 18, 2023
1 parent 456d4cb commit 1951087
Showing 3 changed files with 60 additions and 51 deletions.
12 changes: 8 additions & 4 deletions src/cache.jl
@@ -26,7 +26,6 @@ function cached_compilation(cache::AbstractDict{UInt,V},
key = hash(tt, key)
key = hash(world, key)
key = hash(cfg, key)

# NOTE: no use of lock(::Function)/@lock/get! to avoid try/catch and closure overhead
lock(cache_lock)
obj = get(cache, key, nothing)
@@ -36,6 +35,7 @@ function cached_compilation(cache::AbstractDict{UInt,V},
if obj === nothing || compile_hook[] !== nothing
obj = actual_compilation(cache, key, cfg, ft, tt, compiler, linker)::V
end

return obj::V
end

@@ -45,10 +45,14 @@ end
src = methodinstance(ft, tt)
job = CompilerJob(src, cfg)

global_cache = ci_cache(job)
asm = nothing
# TODO: consider loading the assembly from an on-disk cache here

# compile
# read asm from persistent offline cache
if haskey(global_cache.asm, src)
asm = global_cache.asm[src]
end

if asm === nothing
asm = compiler(job)
end
@@ -57,7 +61,7 @@ end
# in which case the cache will already be populated)
lock(cache_lock) do
haskey(cache, key) && return cache[key]

global_cache.asm[src] = asm
obj = linker(job, asm)
cache[key] = obj
obj
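For context, the hunk above makes actual_compilation consult the per-configuration asm cache (keyed by the MethodInstance) before invoking the compiler, and populate it under cache_lock after a miss. Below is a minimal standalone sketch of that compile-or-load pattern, using a plain Dict in place of the real CodeCache.asm field; compile_fn and link_fn are stand-in callbacks, not GPUCompiler API.

    const ASM_CACHE = Dict{Any,Vector{UInt8}}()

    function compile_or_load(src, compile_fn, link_fn)
        # reuse previously generated assembly when the persistent cache already has it
        asm = get(ASM_CACHE, src, nothing)
        if asm === nothing
            asm = compile_fn(src)      # cache miss: run the expensive compiler
        end
        obj = link_fn(src, asm)        # linking still happens in every session
        ASM_CACHE[src] = asm           # remember the assembly for the next lookup
        return obj
    end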
11 changes: 6 additions & 5 deletions src/jlgen.jl
@@ -255,16 +255,19 @@ using Core.Compiler: CodeInstance, MethodInstance, InferenceParams, Optimization

struct CodeCache
dict::IdDict{MethodInstance,Vector{CodeInstance}}
asm::IdDict{MethodInstance, NamedTuple{(:image, :entry, :external_gvars), Tuple{Vector{UInt8}, String, Vector{String}}}}

CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}())
CodeCache(cache::CodeCache) = new(GPUCompiler.copyAndFilter(cache.dict))
CodeCache() = new(Dict{MethodInstance,Vector{CodeInstance}}(),
Dict{MethodInstance, NamedTuple{(:image, :entry, :external_gvars), Tuple{Vector{UInt8}, String, Vector{String}}}}())

CodeCache(cache::CodeCache) = new(GPUCompiler.copyAndFilter(cache.dict), cache.asm)
end

function copyAndFilter(dict::IdDict)
out = IdDict()
for key in keys(dict)
useKey = true
# why is this an array of code instances? can there be more than one?

for ci in dict[key]
if ci.max_world < typemax(typeof(ci.max_world))
useKey = false
@@ -590,7 +593,6 @@ end

function ci_cache_populate(interp, cache, mt, mi, min_world, max_world)
src = Core.Compiler.typeinf_ext_toplevel(interp, mi)

# inference populates the cache, so we don't need to jl_get_method_inferred
wvc = WorldView(cache, min_world, max_world)
@assert Core.Compiler.haskey(wvc, mi)
Expand Down Expand Up @@ -622,7 +624,6 @@ function ci_cache_lookup(cache, mi, min_world, max_world)
return ci
end


## interface

# for platforms without @cfunction-with-closure support
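As a reference for the new asm field added above, here is what a single entry could look like. The NamedTuple layout is taken from the field's declared type in the diff, but the image bytes and entry-point name below are placeholder values, not compiler output.

    asm_entry = (image = UInt8[0x7f, 0x45, 0x4c, 0x46],   # serialized machine-code image (placeholder bytes)
                 entry = "julia_kernel_1",                 # entry-point symbol name (hypothetical)
                 external_gvars = String[])                # external global variables referenced by the image

    asm_entry isa NamedTuple{(:image, :entry, :external_gvars),
                             Tuple{Vector{UInt8}, String, Vector{String}}}   # true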
88 changes: 46 additions & 42 deletions src/precompilation_cache.jl
@@ -7,25 +7,33 @@ function ci_cache_snapshot()
cleaned_cache_to_save = IdDict()
for key in keys(GPUCompiler.GLOBAL_CI_CACHES)
# Will only keep those elements with infinite ranges
# copy constructor
cleaned_cache_to_save[key] = GPUCompiler.CodeCache(GPUCompiler.GLOBAL_CI_CACHES[key])
end

return cleaned_cache_to_save
end

function ci_cache_delta(previous_snapshot)
current_snapshot = ci_cache_snapshot()
delta_snapshot = IdDict{Tuple{DataType, Core.Compiler.InferenceParams, Core.Compiler.OptimizationParams}, GPUCompiler.CodeCache}()
for (cachekey, codecache) in current_snapshot
for (cachekey, codecache) in current_snapshot # iterate through all caches
if cachekey in keys(previous_snapshot)
for (mi, civ) in codecache.dict
for (mi, civ) in codecache.dict # iterate through all mi
if mi in keys(previous_snapshot[cachekey].dict)
for ci in civ
if !(ci in previous_snapshot[cachekey].dict[mi])
if !(cachekey in keys(delta_snapshot))
delta_snapshot[cachekey] = GPUCompiler.CodeCache()
delta_snapshot[cachekey].dict[mi] = Vector{CodeInstance}()
if haskey(codecache.asm, mi)
delta_snapshot[cachekey].asm[mi] = codecache.asm[mi]
end
elseif !(mi in keys(delta_snapshot[cachekey].dict))
delta_snapshot[cachekey].dict[mi] = Vector{CodeInstance}()
if haskey(codecache.asm, mi)
delta_snapshot[cachekey].asm[mi] = codecache.asm[mi]
end
end

push!(delta_snapshot[cachekey].dict[mi], ci)
@@ -36,66 +44,45 @@ function ci_cache_delta(previous_snapshot)
if !(cachekey in keys(delta_snapshot))
delta_snapshot[cachekey] = GPUCompiler.CodeCache()
end

if haskey(codecache.asm, mi)
delta_snapshot[cachekey].asm[mi] = codecache.asm[mi]
end
delta_snapshot[cachekey].dict[mi] = civ
end
end
else
delta_snapshot[cachekey] = current_snapshot[cachekey]
end
end

return delta_snapshot
end

function print_keys(caches)
println("************")
for (key, cache) in caches
for (mi, civ) in cache.dict
println("$mi -> $(length(civ))")
end
end
println("************")
end
function ci_cache_insert(cache)
if !is_precompiling()
# first clean the cache
cleaned_cache = IdDict()
for (key, c) in cache
usedCache = false
newCodeCache = GPUCompiler.CodeCache()
for (mi, civ) in c.dict
new_civ = Vector()
for ci in civ
if ci.min_world <= ci.max_world
push!(new_civ, ci)
end
end
if length(new_civ) > 0
usedCache = true
newCodeCache.dict[mi] = new_civ
end
end
if usedCache
cleaned_cache[key] = newCodeCache
end
end

# need to merge caches at the code instance level
for (key, local_cache) in cleaned_cache
for (key, local_cache) in cache
if haskey(GPUCompiler.GLOBAL_CI_CACHES, key)
global_cache = GPUCompiler.GLOBAL_CI_CACHES[key]
#local_cache = cache[key]
for (mi, civ) in (local_cache.dict)
# this should be one since there is only one range that is infinite
@assert length(civ) == 1
# add all code instances to global cache
# could move truncating code to set index
ci = civ[1]
if haskey(global_cache.dict, mi)
gciv = global_cache.dict[mi]
# truncation code
# sort by min world age, then make sure no age ranges overlap (this part is unneeded)
sort(gciv, by=x->x.min_world)
if ci.min_world > gciv[length(gciv)].min_world
invalidate_code_cache(global_cache, mi, ci.min_world - 1)
Core.Compiler.setindex!(global_cache, ci, mi)
else
println("Should not get here?")
@assert false
end
else
# occurs if we kill everything in the parent and then need to store in child
Core.Compiler.setindex!(global_cache, ci, mi)
Core.Compiler.setindex!(global_cache, civ[1], mi)
#@assert haskey(local_cache.asm, mi)
if haskey(local_cache.asm, mi)
global_cache.asm[mi] = local_cache.asm[mi]
end
end
else
@@ -118,3 +105,20 @@ function precompile_gpucompiler(job)
GPUCompiler.ci_cache_populate(interp, cache, mt, job.source, job.world, typemax(Cint))
end
end

"""
Generate a precompile file for the current state of the cache
"""
function generate_precompilation_file(snapshot, filename, precompilation_function)
method_instances = []
for (cachekey, cache) in snapshot
for (mi, civ) in cache.dict
push!(method_instances, mi)
end
end

precompile_statements = join(["$precompilation_function($(mi.specTypes.parameters[1]), Core.$(mi.specTypes.parameters[2:length(mi.specTypes.parameters)]))" for mi in method_instances], '\n')
open(filename, "w") do file
write(file, precompile_statements)
end
end
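To tie the pieces together, here is a hedged usage sketch of how a downstream package might drive these helpers during precompilation. The module name MyKernels, its kernel-precompilation entry point, and the output file name are hypothetical; the GPUCompiler calls mirror the functions defined in this file.

    const SNAPSHOT = GPUCompiler.ci_cache_snapshot()     # taken when the module starts precompiling

    # ... call precompile_gpucompiler(job) here for every kernel the package ships ...

    const DELTA = GPUCompiler.ci_cache_delta(SNAPSHOT)   # everything inferred and compiled since the snapshot

    function __init__()
        GPUCompiler.ci_cache_insert(DELTA)               # replay the cached code into the live global caches
    end

    # optionally emit explicit precompile statements for the captured method instances
    GPUCompiler.generate_precompilation_file(DELTA, "precompile_statements.jl",
                                             "MyKernels.precompile_kernels")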
