Skip to content

Commit

Permalink
wip: adjustments to the latest master
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Mar 19, 2024
1 parent 0f74c78 commit 9883872
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
56 changes: 44 additions & 12 deletions src/stage2/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interp
CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing

@static if VERSION v"1.11.0-DEV.1552"
CC.cache_owner(ei::ADInterpreter) = ei.opt
end

function CC.code_cache(ei::ADInterpreter)
while ei.current_level > lastindex(ei.opt)
push!(ei.opt, Dict{MethodInstance, Any}())
Expand Down Expand Up @@ -291,21 +295,17 @@ function CC.finish(state::InferenceState, interp::ADInterpreter)
return res
end

function CC.transform_result_for_cache(interp::ADInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end

function CC.inlining_policy(interp::ADInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo),
stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8),
mi::MethodInstance, argtypes::Vector{Any})
const StmtFlag = @static VERSION v"1.11.0-DEV.377" ? UInt32 : UInt8
function diffractor_inlining_policy(@nospecialize(src), @nospecialize(info::CC.CallInfo),
stmt_flag::StmtFlag)
# Disallow inlining things away that have an frule
if isa(info, FRuleCallInfo)
return nothing
end
if isa(src, CC.SemiConcreteResult)
return src
@static if VERSION < v"1.11.0-DEV.879"
if isa(src, CC.SemiConcreteResult)
return src
end
end
@assert isa(src, Cthulhu.OptimizedSource) || isnothing(src)
if isa(src, Cthulhu.OptimizedSource)
Expand All @@ -314,12 +314,44 @@ function CC.inlining_policy(interp::ADInterpreter,
end
return nothing
end
return missing
end

@static if VERSION v"1.12.0-DEV.45"
function CC.transform_result_for_cache(interp::ADInterpreter,
::MethodInstance, ::WorldRange, result::InferenceResult, ::Bool)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end
function CC.src_inlining_policy(interp::ADInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag)
ret = diffractor_inlining_policy(src, info, stmt_flag)
ret === nothing && return false
ret !== missing && return true
return CC.src_inlining_policy(interp::AbstractInterpreter,
src::Any, info::CC.CallInfo, stmt_flag::StmtFlag)
end
CC.retrieve_ir_for_inlining(cached_result::CodeInstance, src::Cthulhu.OptimizedSource) =
CC.retrieve_ir_for_inlining(cached_result.def, src.ir, true)
CC.retrieve_ir_for_inlining(mi::MethodInstance, src::Cthulhu.OptimizedSource, preserve_local_sources::Bool) =
CC.retrieve_ir_for_inlining(mi, src.ir, preserve_local_sources)
else
function CC.transform_result_for_cache(interp::ADInterpreter,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
end
function CC.inlining_policy(interp::ADInterpreter,
@nospecialize(src), @nospecialize(info::CC.CallInfo), stmt_flag::StmtFlag,
mi::MethodInstance, argtypes::Vector{Any})
ret = diffractor_inlining_policy(src, info, stmt_flag)
ret === nothing && return nothing
ret !== missing && return ret
# the default inlining policy may try additional effor to find the source in a local cache
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
nothing, info::CC.CallInfo,
stmt_flag::(@static VERSION ≥ v"1.11.0-DEV.377" ? UInt32 : UInt8),
stmt_flag::StmtFlag,
mi::MethodInstance, argtypes::Vector{Any})
end
end

#=
function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const bwd = Diffractor.PrimeDerivativeBack
@testset verbose=true "Diffractor.jl" begin # overall testset, ensures all tests run

@testset "$file" for file in (
"extra_rules.jl"
"extra_rules.jl",
"stage2_fwd.jl",
"tangent.jl",
"forward_diff_no_inf.jl",
Expand Down

0 comments on commit 9883872

Please sign in to comment.