Skip to content

Commit

Permalink
wip: make Cthulhu cache CodeInstance-based
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Dec 19, 2024
1 parent 2a935da commit c1c6407
Show file tree
Hide file tree
Showing 11 changed files with 275 additions and 302 deletions.
172 changes: 83 additions & 89 deletions src/Cthulhu.jl

Large diffs are not rendered by default.

102 changes: 48 additions & 54 deletions src/callsite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@ using Unicode
abstract type CallInfo end

# Call could be resolved to a singular MI
struct MICallInfo <: CallInfo
mi::MethodInstance
struct EdgeCallInfo <: CallInfo
ci::CodeInstance
rt
effects::Effects
exct
function MICallInfo(mi::MethodInstance, @nospecialize(rt), effects, @nospecialize(exct=nothing))
function EdgeCallInfo(ci::CodeInstance, @nospecialize(rt), effects::Effects, @nospecialize(exct=nothing))
if isa(rt, LimitedAccuracy)
return LimitedCallInfo(new(mi, ignorelimited(rt), effects, exct))
return LimitedCallInfo(new(ci, ignorelimited(rt), effects, exct))
else
return new(mi, rt, effects, exct)
return new(ci, rt, effects, exct)
end
end
end
get_mi(ci::MICallInfo) = ci.mi
get_rt(ci::MICallInfo) = ci.rt
get_effects(ci::MICallInfo) = ci.effects
get_exct(ci::MICallInfo) = ci.exct
get_ci(ci::EdgeCallInfo) = ci.ci
get_rt(ci::EdgeCallInfo) = ci.rt
get_effects(ci::EdgeCallInfo) = ci.effects
get_exct(ci::EdgeCallInfo) = ci.exct

abstract type WrappedCallInfo <: CallInfo end

get_wrapped(ci::WrappedCallInfo) = ci.wrapped
ignorewrappers(ci::CallInfo) = ci
ignorewrappers(ci::WrappedCallInfo) = ignorewrappers(get_wrapped(ci))
get_mi(ci::WrappedCallInfo) = get_mi(ignorewrappers(ci))
get_ci(ci::WrappedCallInfo) = get_ci(ignorewrappers(ci))
get_rt(ci::WrappedCallInfo) = get_rt(ignorewrappers(ci))
get_effects(ci::WrappedCallInfo) = get_effects(ignorewrappers(ci))
get_exct(ci::WrappedCallInfo) = get_exct(ignorewrappers(ci))
Expand All @@ -44,22 +44,17 @@ struct RTCallInfo <: CallInfo
exct
end
get_rt(ci::RTCallInfo) = ci.rt
get_mi(ci::RTCallInfo) = nothing
get_ci(ci::RTCallInfo) = nothing
get_effects(ci::RTCallInfo) = Effects()
get_exct(ci::RTCallInfo) = ci.exct

# uncached callsite, we can't recurse into this call
struct UncachedCallInfo <: WrappedCallInfo
wrapped::CallInfo
end

struct PureCallInfo <: CallInfo
argtypes::Vector{Any}
rt
PureCallInfo(argtypes::Vector{Any}, @nospecialize(rt)) =
new(argtypes, rt)
end
get_mi(::PureCallInfo) = nothing
get_ci(::PureCallInfo) = nothing
get_rt(pci::PureCallInfo) = pci.rt
get_effects(::PureCallInfo) = EFFECTS_TOTAL
get_exct(::PureCallInfo) = Union{}
Expand All @@ -69,7 +64,7 @@ struct FailedCallInfo <: CallInfo
sig
rt
end
get_mi(ci::FailedCallInfo) = fail(ci)
get_ci(ci::FailedCallInfo) = fail(ci)
get_rt(ci::FailedCallInfo) = fail(ci)
get_effects(ci::FailedCallInfo) = fail(ci)
get_exct(ci::FailedCallInfo) = fail(ci)
Expand All @@ -83,7 +78,7 @@ struct GeneratedCallInfo <: CallInfo
sig
rt
end
get_mi(genci::GeneratedCallInfo) = fail(genci)
get_ci(genci::GeneratedCallInfo) = fail(genci)
get_rt(genci::GeneratedCallInfo) = fail(genci)
get_effects(genci::GeneratedCallInfo) = fail(genci)
get_exct(genci::GeneratedCallInfo) = fail(genci)
Expand All @@ -101,15 +96,15 @@ struct MultiCallInfo <: CallInfo
@nospecialize(exct=nothing)) =
new(sig, rt, exct, callinfos)
end
get_mi(ci::MultiCallInfo) = error("Can't extract MethodInstance from multiple call informations")
get_ci(ci::MultiCallInfo) = error("Can't extract MethodInstance from multiple call informations")
get_rt(ci::MultiCallInfo) = ci.rt
get_effects(mci::MultiCallInfo) = mapreduce(get_effects, CC.merge_effects, mci.callinfos)
get_exct(ci::MultiCallInfo) = ci.exct

struct TaskCallInfo <: CallInfo
ci::CallInfo
end
get_mi(tci::TaskCallInfo) = get_mi(tci.ci)
get_ci(tci::TaskCallInfo) = get_ci(tci.ci)
get_rt(tci::TaskCallInfo) = get_rt(tci.ci)
get_effects(tci::TaskCallInfo) = get_effects(tci.ci)
get_exct(tci::TaskCallInfo) = get_exct(tci.ci)
Expand All @@ -118,7 +113,7 @@ struct InvokeCallInfo <: CallInfo
ci::CallInfo
InvokeCallInfo(@nospecialize ci::CallInfo) = new(ci)
end
get_mi(ici::InvokeCallInfo) = get_mi(ici.ci)
get_ci(ici::InvokeCallInfo) = get_ci(ici.ci)
get_rt(ici::InvokeCallInfo) = get_rt(ici.ci)
get_effects(ici::InvokeCallInfo) = get_effects(ici.ci)
get_exct(ici::InvokeCallInfo) = get_exct(ici.ci)
Expand All @@ -128,7 +123,7 @@ struct OCCallInfo <: CallInfo
ci::CallInfo
OCCallInfo(@nospecialize ci::CallInfo) = new(ci)
end
get_mi(occi::OCCallInfo) = get_mi(occi.ci)
get_ci(occi::OCCallInfo) = get_ci(occi.ci)
get_rt(occi::OCCallInfo) = get_rt(occi.ci)
get_effects(occi::OCCallInfo) = get_effects(occi.ci)
get_exct(occi::OCCallInfo) = get_exct(occi.ci)
Expand All @@ -137,52 +132,52 @@ get_exct(occi::OCCallInfo) = get_exct(occi.ci)
struct ReturnTypeCallInfo <: CallInfo
vmi::CallInfo # virtualized method call
end
get_mi((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_mi(vmi)
get_ci((; vmi)::ReturnTypeCallInfo) = isa(vmi, FailedCallInfo) ? nothing : get_ci(vmi)
get_rt((; vmi)::ReturnTypeCallInfo) = Type{isa(vmi, FailedCallInfo) ? Union{} : widenconst(get_rt(vmi))}
get_effects(::ReturnTypeCallInfo) = EFFECTS_TOTAL
get_exct(::ReturnTypeCallInfo) = Union{} # FIXME

struct ConstPropCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
result::InferenceResult
end
get_mi(cpci::ConstPropCallInfo) = cpci.result.linfo
get_rt(cpci::ConstPropCallInfo) = get_rt(cpci.mi)
get_ci(cpci::ConstPropCallInfo) = get_ci(cpci.ci)
get_rt(cpci::ConstPropCallInfo) = get_rt(cpci.ci)
get_effects(cpci::ConstPropCallInfo) = get_effects(cpci.result)
get_exct(cpci::ConstPropCallInfo) = get_exct(cpci.mi)
get_exct(cpci::ConstPropCallInfo) = get_exct(cpci.ci)

struct ConcreteCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
argtypes::ArgTypes
end
get_mi(ceci::ConcreteCallInfo) = get_mi(ceci.mi)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.mi)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.mi)
get_exct(cici::ConcreteCallInfo) = get_exct(ceci.mi)
get_ci(ceci::ConcreteCallInfo) = get_ci(ceci.ci)
get_rt(ceci::ConcreteCallInfo) = get_rt(ceci.ci)
get_effects(ceci::ConcreteCallInfo) = get_effects(ceci.ci)
get_exct(cici::ConcreteCallInfo) = get_exct(ceci.ci)

struct SemiConcreteCallInfo <: CallInfo
mi::CallInfo
ci::CallInfo
ir::IRCode
end
get_mi(scci::SemiConcreteCallInfo) = get_mi(scci.mi)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.mi)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.mi)
get_exct(scci::SemiConcreteCallInfo) = get_exct(scci.mi)
get_ci(scci::SemiConcreteCallInfo) = get_ci(scci.ci)
get_rt(scci::SemiConcreteCallInfo) = get_rt(scci.ci)
get_effects(scci::SemiConcreteCallInfo) = get_effects(scci.ci)
get_exct(scci::SemiConcreteCallInfo) = get_exct(scci.ci)

# CUDA callsite
struct CuCallInfo <: CallInfo
cumi::MICallInfo
ci::EdgeCallInfo
end
get_mi(gci::CuCallInfo) = get_mi(gci.cumi)
get_rt(gci::CuCallInfo) = get_rt(gci.cumi)
get_effects(gci::CuCallInfo) = get_effects(gci.cumi)
get_ci(gci::CuCallInfo) = get_ci(gci.ci)
get_rt(gci::CuCallInfo) = get_rt(gci.ci)
get_effects(gci::CuCallInfo) = get_effects(gci.ci)

struct Callsite
id::Int # ssa-id
info::CallInfo
head::Symbol
end
get_mi(c::Callsite) = get_mi(c.info)
get_ci(c::Callsite) = get_ci(c.info)
get_effects(c::Callsite) = get_effects(c.info)

# Callsite printing
Expand Down Expand Up @@ -277,17 +272,17 @@ function Base.show(io::IO, (;exct)::ExctWrapper)
printstyled(io, "(↑::", exct, ")"; color)
end

function show_callinfo(limiter, mici::MICallInfo)
mi = mici.mi
function show_callinfo(limiter, ci::EdgeCallInfo)
mi = ci.ci.def
tt = (Base.unwrap_unionall(mi.specTypes)::DataType).parameters[2:end]
if !isa(mi.def, Method)
name = ":toplevel"
else
name = mi.def.name
end
rt = get_rt(mici)
exct = get_exct(mici)
__show_limited(limiter, name, tt, rt, get_effects(mici), exct)
rt = get_rt(ci)
exct = get_exct(ci)
__show_limited(limiter, name, tt, rt, get_effects(ci), exct)
end

function show_callinfo(limiter, ci::Union{MultiCallInfo, FailedCallInfo, GeneratedCallInfo})
Expand Down Expand Up @@ -317,20 +312,20 @@ function show_callinfo(limiter, ci::ConstPropCallInfo)
# XXX: The first argument could be const-overriden too
name = ci.result.linfo.def.name
tt = ci.result.argtypes[2:end]
ci = ignorewrappers(ci.mi)::MICallInfo
ci = ignorewrappers(ci.ci)::EdgeCallInfo
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::SemiConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
name = get_ci(ci).def.def.name
tt = ci.ir.argtypes[2:end]
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end

function show_callinfo(limiter, ci::ConcreteCallInfo)
# XXX: The first argument could be const-overriden too
name = get_mi(ci).def.name
name = get_ci(ci).def.def.name
tt = ci.argtypes[2:end]
__show_limited(limiter, name, tt, get_rt(ci), get_effects(ci))
end
Expand Down Expand Up @@ -435,7 +430,7 @@ function Base.show(io::IO, c::Callsite)
limiter = TextWidthLimiter(io, cols)
limiter.width += 1 # for the '%' character
print(limiter, string(c.id))
if isa(info, MICallInfo)
if isa(info, EdgeCallInfo)
print(limiter, optimize ? string(" = ", c.head, ' ') : " = ")
show_callinfo(limiter, info)
else
Expand All @@ -457,7 +452,6 @@ function wrapped_callinfo(limiter, ci::WrappedCallInfo)
print(limiter, " > ")
end
_wrapped_callinfo(limiter, ::LimitedCallInfo) = print(limiter, "limited")
_wrapped_callinfo(limiter, ::UncachedCallInfo) = print(limiter, "uncached")

# is_callsite returns true if `call` dispatches to `callee`
# See also `maybe_callsite` below
Expand Down Expand Up @@ -527,7 +521,7 @@ function maybe_callsite(info::RTCallInfo, @nospecialize(tt::Type))
end
return true
end
function maybe_callsite(info::MICallInfo, @nospecialize(tt::Type))
function maybe_callsite(info::EdgeCallInfo, @nospecialize(tt::Type))
return tt <: info.mi.specTypes
end

Expand Down
22 changes: 8 additions & 14 deletions src/codeview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,21 @@ is_type_unstable(@nospecialize(type)) = type isa Type && (!Base.isdispatchelem(t

cthulhu_warntype(args...; kwargs...) = cthulhu_warntype(stdout::IO, args...; kwargs...)
function cthulhu_warntype(io::IO, debuginfo::AnyDebugInfo,
src::Union{CodeInfo,IRCode}, @nospecialize(rt), effects::Effects, mi::Union{Nothing,MethodInstance}=nothing;
src::Union{CodeInfo,IRCode}, @nospecialize(rt), effects::Effects, codeinst::Union{Nothing,CodeInstance}=nothing;
hide_type_stable::Bool=false, inline_cost::Bool=false, optimize::Bool=false,
interp::CthulhuInterpreter=CthulhuInterpreter())
if inline_cost
isa(mi, MethodInstance) || error("Need a MethodInstance to show inlining costs. Call `cthulhu_typed` directly instead.")
end
cthulhu_typed(io, debuginfo, src, rt, nothing, effects, mi; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
cthulhu_typed(io, debuginfo, src, rt, nothing, effects, codeinst; iswarn=true, optimize, hide_type_stable, inline_cost, interp)
return nothing
end

# # for API consistency with the others
# function cthulhu_typed(io::IO, mi::MethodInstance, optimize, debuginfo, params, config::CthulhuConfig)
# interp = mkinterp(mi)
# (; src, rt, infos, slottypes) = lookup(interp, mi, optimize)
# ci = Cthulhu.preprocess_ci!(src, mi, optimize, config)
# cthulhu_typed(io, debuginfo, src, rt, mi)
# end

cthulhu_typed(io::IO, debuginfo::DebugInfo, args...; kwargs...) =
cthulhu_typed(io, Symbol(debuginfo), args...; kwargs...)
function cthulhu_typed(io::IO, debuginfo::Symbol,
src::Union{CodeInfo,IRCode}, @nospecialize(rt), @nospecialize(exct),
effects::Effects, mi::Union{Nothing,MethodInstance};
effects::Effects, codeinst::Union{Nothing,CodeInstance};
iswarn::Bool=false, hide_type_stable::Bool=false, optimize::Bool=true,
pc2remarks::Union{Nothing,PC2Remarks}=nothing,
pc2effects::Union{Nothing,PC2Effects}=nothing,
Expand All @@ -148,6 +140,8 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
inlay_types_vscode::Bool=false, diagnostics_vscode::Bool=false, jump_always::Bool=false,
interp::AbstractInterpreter=CthulhuInterpreter())

mi = codeinst === nothing ? nothing : codeinst.def

debuginfo = IRShow.debuginfo(debuginfo)
lineprinter = __debuginfo[debuginfo]
rettype = ignorelimited(rt)
Expand Down Expand Up @@ -316,11 +310,11 @@ function cthulhu_typed(io::IO, debuginfo::Symbol,
end
println(lambda_io)
else
isa(mi, MethodInstance) || throw("`mi::MethodInstance` is required")
isa(codeinst, CodeInstance) || throw("`codeinst::CodeInstance` is required")
cfg = src isa IRCode ? src.cfg : CC.compute_basic_blocks(src.code)
max_bb_idx_size = length(string(length(cfg.blocks)))
str = irshow_config.line_info_preprinter(lambda_io, " "^(max_bb_idx_size + 2), -1)
callsite = Callsite(0, MICallInfo(mi, rettype, effects, exct), :invoke)
callsite = Callsite(0, EdgeCallInfo(codeinst, rettype, effects, exct), :invoke)
println(lambda_io, "", ""^(max_bb_idx_size), str, " ", callsite)
end

Expand Down Expand Up @@ -459,7 +453,7 @@ function Base.show(
(; interp, mi) = b
(; effects) = lookup(interp, mi, optimize)
if get(io, :typeinfo, Any) === Bookmark # a hack to check if in Vector etc.
print(io, Callsite(-1, MICallInfo(b.mi, rt, Effects()), :invoke))
print(io, Callsite(-1, EdgeCallInfo(b.mi, rt, Effects()), :invoke))
print(io, " (world: ", world, ")")
return
end
Expand Down
Loading

0 comments on commit c1c6407

Please sign in to comment.