diff --git a/Project.toml b/Project.toml index d5c7a6d4b..f87d29aa5 100644 --- a/Project.toml +++ b/Project.toml @@ -27,8 +27,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" MethodAnalysis = "85b6ec6f-f7df-4429-9514-a64bcd9ee824" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AbstractTrees", "ColorTypes", "Documenter", "FixedPointNumbers", "InteractiveUtils", "MethodAnalysis", "Pkg", "Test"] +test = ["AbstractTrees", "ColorTypes", "Documenter", "FixedPointNumbers", "InteractiveUtils", "MethodAnalysis", "Random", "Pkg", "Test"] diff --git a/src/parcel_snoopi.jl b/src/parcel_snoopi.jl index ddd038ab2..30a3bb121 100644 --- a/src/parcel_snoopi.jl +++ b/src/parcel_snoopi.jl @@ -157,6 +157,9 @@ tuplestring(params) = isempty(params) ? "()" : '(' * join(params, ',') * ",)" wrap_precompile(ttstr::AbstractString) = "Base.precompile(" * ttstr * ')' # use `Base.` to avoid conflict with Core and Pkg +append_time(str, ::Nothing) = str +append_time(str, t::AbstractFloat) = str * " # time: " * string(Float32(t)) + """ add_if_evals!(pclist, mod::Module, fstr, params, tt; prefix = "", check_eval::Bool=true) @@ -164,15 +167,16 @@ Adds the precompilation statements only if they can be evaled. It uses [`can_eva In some cases, you may want to bypass this function by passing `check_eval=true` to increase the snooping performance. """ -function add_if_evals!(pclist, mod::Module, fstr, params, tt; prefix = "", check_eval::Bool=true) +function add_if_evals!(pclist, mod::Module, fstr, params, tt; prefix = "", check_eval::Bool=true, time=nothing) ttstr = tupletypestring(fstr, params) can, exc = can_eval(mod, ttstr, check_eval) if can - push!(pclist, prefix*wrap_precompile(ttstr)) + push!(pclist, append_time(prefix*wrap_precompile(ttstr), time)) + return true else @debug "Module $mod: skipping $tt due to eval failure" exception=exc _module=mod _file="precompile_$mod.jl" end - return pclist + return false end function reprcontext(mod::Module, @nospecialize(T::Type)) @@ -188,6 +192,24 @@ function reprcontext(mod::Module, @nospecialize(T::Type)) end end +function known_type(mod::Module, @nospecialize(T::Type)) + # First check whether supplying module context allows evaluation + rplain = repr(T; context=:module=>mod) + try + ex = Meta.parse(rplain) + Core.eval(mod, ex) + return true + catch + # Add full module context + try + Core.eval(mod, Meta.parse(repr(T; context=:module=>nothing))) + return true + catch + end + end + return false +end + function handle_kwbody(topmod::Module, m::Method, paramrepr, tt, fstr="fbody"; check_eval = true) nameparent = Symbol(match(r"^#([^#]*)#", String(m.name)).captures[1]) if !isdefined(m.module, nameparent) # TODO: replace debugging with error-handling @@ -238,7 +260,7 @@ function parcel(tinf::AbstractVector{Tuple{Float64, Core.MethodInstance}}; modgens = Dict{Module, Vector{Method}}() # methods with generators in a module mods = OrderedSet{Module}() # module of each parameter for a given method sym_module = Dict{Symbol, Module}() # 1-1 association between modules and module name - for (t, mi) in reverse(tinf) + for (_, mi) in reverse(tinf) isdefined(mi, :specTypes) || continue tt = mi.specTypes m = mi.def @@ -268,78 +290,7 @@ function parcel(tinf::AbstractVector{Tuple{Float64, Core.MethodInstance}}; Core.eval(topmod, lookup_kwbody_ex) end end - # Create the string representation of the signature - # Use special care with keyword functions, anonymous functions - p = tt.parameters[1] # the portion of the signature related to the function itself - paramrepr = map(T->reprcontext(topmod, T), Iterators.drop(tt.parameters, 1)) # all the rest of the args - - if any(str->occursin('#', str), paramrepr) - @debug "Skipping $tt due to argument types having anonymous bindings" - continue - end - mname, mmod = String(p.name.name), m.module # m.name strips the kw identifier - mkw = match(kwrex, mname) - mkwbody = match(kwbodyrex, mname) - isgen = match(genrex, mname) !== nothing - isanon = match(anonrex, mname) !== nothing || match(innerrex, mname) !== nothing - isgen && (mkwbody = nothing) - if VERSION < v"1.4.0-DEV.215" # before this version, we can't robustly look up kwbody callers (missing `nkw`) - isanon |= mkwbody !== nothing # treat kwbody methods the same way we treat anonymous functions - mkwbody = nothing - end - if mkw !== nothing - # Keyword function - fname = mkw.captures[1] === nothing ? mkw.captures[2] : mkw.captures[1] - fkw = "Core.kwftype(typeof($mmod.$fname))" - add_if_evals!(pc[topmodname], topmod, fkw, paramrepr, tt; check_eval=check_eval) - elseif mkwbody !== nothing - ret = handle_kwbody(topmod, m, paramrepr, tt; check_eval = check_eval) - if ret !== nothing - push!(pc[topmodname], ret) - end - elseif isgen - # Generator for a @generated function - if !haskey(modgens, m.module) - callers = modgens[m.module] = methods_with_generators(m.module) - else - callers = modgens[m.module] - end - for caller in callers - if nameof(caller.generator.gen) == m.name - # determine whether the generator is being called from a kwbody method - sig = Base.unwrap_unionall(caller.sig) - cname, cmod = String(sig.parameters[1].name.name), caller.module - cparamrepr = map(repr, Iterators.drop(sig.parameters, 1)) - csigstr = tuplestring(cparamrepr) - mkwc = match(kwbodyrex, cname) - if mkwc === nothing - getgen = "typeof(which($cmod.$(caller.name),$csigstr).generator.gen)" - add_if_evals!(pc[topmodname], topmod, getgen, paramrepr, tt; check_eval=check_eval) - else - if VERSION >= v"1.4.0-DEV.215" - getgen = "which(Core.kwfunc($cmod.$(mkwc.captures[1])),$csigstr).generator.gen" - ret = handle_kwbody(topmod, caller, cparamrepr, tt; check_eval = check_eval) #, getgen) - if ret !== nothing - push!(pc[topmodname], ret) - end - else - # Bail and treat as if anonymous - prefix = "isdefined($mmod, Symbol(\"$mname\")) && " - fstr = "getfield($mmod, Symbol(\"$mname\"))" # this is universal, var is Julia 1.3+ - add_if_evals!(pc[topmodname], topmod, fstr, paramrepr, tt; prefix=prefix, check_eval=check_eval) - end - end - break - end - end - elseif isanon - # Anonymous function, wrap in an `isdefined` - prefix = "isdefined($mmod, Symbol(\"$mname\")) && " - fstr = "getfield($mmod, Symbol(\"$mname\"))" # this is universal, var is Julia 1.3+ - add_if_evals!(pc[topmodname], topmod, fstr, paramrepr, tt; prefix=prefix, check_eval = check_eval) - else - add_if_evals!(pc[topmodname], topmod, reprcontext(topmod, p), paramrepr, tt, check_eval = check_eval) - end + add_repr!(pc[topmodname], modgens, mi, topmod; check_eval=check_eval) end # loop over the output @@ -352,6 +303,84 @@ function parcel(tinf::AbstractVector{Tuple{Float64, Core.MethodInstance}}; return Dict(mod=>collect(lines) for (mod, lines) in pc) # convert Set to Array before return end +function add_repr!(list, modgens::Dict{Module, Vector{Method}}, mi::MethodInstance, topmod::Module=mi.def.module; check_eval::Bool, time=nothing) + # Create the string representation of the signature + # Use special care with keyword functions, anonymous functions + tt = mi.specTypes + m = mi.def + p = tt.parameters[1] # the portion of the signature related to the function itself + paramrepr = map(T->reprcontext(topmod, T), Iterators.drop(tt.parameters, 1)) # all the rest of the args + + if any(str->occursin('#', str), paramrepr) + @debug "Skipping $tt due to argument types having anonymous bindings" + return false + end + mname, mmod = String(p.name.name), m.module # m.name strips the kw identifier + mkw = match(kwrex, mname) + mkwbody = match(kwbodyrex, mname) + isgen = match(genrex, mname) !== nothing + isanon = match(anonrex, mname) !== nothing || match(innerrex, mname) !== nothing + isgen && (mkwbody = nothing) + if VERSION < v"1.4.0-DEV.215" # before this version, we can't robustly look up kwbody callers (missing `nkw`) + isanon |= mkwbody !== nothing # treat kwbody methods the same way we treat anonymous functions + mkwbody = nothing + end + if mkw !== nothing + # Keyword function + fname = mkw.captures[1] === nothing ? mkw.captures[2] : mkw.captures[1] + fkw = "Core.kwftype(typeof($mmod.$fname))" + return add_if_evals!(list, topmod, fkw, paramrepr, tt; check_eval=check_eval, time=time) + elseif mkwbody !== nothing + ret = handle_kwbody(topmod, m, paramrepr, tt; check_eval = check_eval) + if ret !== nothing + push!(list, append_time(ret, time)) + return true + end + elseif isgen + # Generator for a @generated function + if !haskey(modgens, m.module) + callers = modgens[m.module] = methods_with_generators(m.module) + else + callers = modgens[m.module] + end + for caller in callers + if nameof(caller.generator.gen) == m.name + # determine whether the generator is being called from a kwbody method + sig = Base.unwrap_unionall(caller.sig) + cname, cmod = String(sig.parameters[1].name.name), caller.module + cparamrepr = map(repr, Iterators.drop(sig.parameters, 1)) + csigstr = tuplestring(cparamrepr) + mkwc = match(kwbodyrex, cname) + if mkwc === nothing + getgen = "typeof(which($cmod.$(caller.name),$csigstr).generator.gen)" + return add_if_evals!(list, topmod, getgen, paramrepr, tt; check_eval=check_eval, time=time) + else + if VERSION >= v"1.4.0-DEV.215" + getgen = "which(Core.kwfunc($cmod.$(mkwc.captures[1])),$csigstr).generator.gen" + ret = handle_kwbody(topmod, caller, cparamrepr, tt; check_eval = check_eval) #, getgen) + if ret !== nothing + push!(list, append_time(ret, time)) + return true + end + else + # Bail and treat as if anonymous + prefix = "isdefined($mmod, Symbol(\"$mname\")) && " + fstr = "getfield($mmod, Symbol(\"$mname\"))" # this is universal, var is Julia 1.3+ + return add_if_evals!(list, topmod, fstr, paramrepr, tt; prefix=prefix, check_eval=check_eval, time=time) + end + end + break + end + end + elseif isanon + # Anonymous function, wrap in an `isdefined` + prefix = "isdefined($mmod, Symbol(\"$mname\")) && " + fstr = "getfield($mmod, Symbol(\"$mname\"))" # this is universal, var is Julia 1.3+ + return add_if_evals!(list, topmod, fstr, paramrepr, tt; prefix=prefix, check_eval = check_eval, time=time) + end + return add_if_evals!(list, topmod, reprcontext(topmod, p), paramrepr, tt, check_eval = check_eval, time=time) +end + """ exclusions_remover!(pcI, exclusions) diff --git a/src/parcel_snoopi_deep.jl b/src/parcel_snoopi_deep.jl index 27f59606e..397970ec7 100644 --- a/src/parcel_snoopi_deep.jl +++ b/src/parcel_snoopi_deep.jl @@ -97,6 +97,131 @@ function build_inclusive_times(t::Timing) return InclusiveTiming(t.mi_info, incl_time, t.start_time, child_times) end +struct Precompiles + mi_info::Core.Compiler.Timings.InferenceFrameInfo + total_time::UInt64 + precompiles::Vector{Tuple{UInt64,MethodInstance}} +end +Precompiles(it::InclusiveTiming) = Precompiles(it.mi_info, it.inclusive_time, Tuple{UInt64,MethodInstance}[]) + +inclusive_time(t::Precompiles) = t.total_time +precompilable_time(precompiles::Vector{Tuple{UInt64,MethodInstance}}) = sum(first, precompiles; init=zero(UInt64)) +precompilable_time(pc::Precompiles) = precompilable_time(pc.precompiles) + +function Base.show(io::IO, pc::Precompiles) + tpc = precompilable_time(pc) + print(io, "Precompiles: ", pc.total_time/10^9, " for ", pc.mi_info.mi, + " had ", length(pc.precompiles), " precompilable roots reclaiming ", tpc/10^9, + " ($(round(Int, 100*tpc/pc.total_time))%)") +end + +precompilable_roots(t::Timing) = precompilable_roots(build_inclusive_times(t)) +function precompilable_roots(t::InclusiveTiming) + pcs = [precompilable_roots!(Precompiles(it), it) for it in t.children] + tpc = precompilable_time.(pcs) + p = sortperm(tpc) + return pcs[p] +end + +function precompilable_roots!(pc, t::InclusiveTiming) + mi = t.mi_info.mi + m = mi.def + if isa(m, Method) + mod = m.module + params = Base.unwrap_unionall(mi.specTypes)::DataType + can_eval = true + for p in params.parameters + if !known_type(mod, p) + can_eval = false + break + end + end + if can_eval + push!(pc.precompiles, (t.inclusive_time, mi)) + return pc + end + end + foreach(t.children) do c + precompilable_roots!(pc, c) + end + return pc +end + +function parcel(pcs::Vector{Precompiles}) + tosecs((t, mi)::Tuple{UInt64,MethodInstance}) = (t/10^9, mi) + pcdict = Dict{Module,Vector{Tuple{UInt64,MethodInstance}}}() + t_grand_total = sum(inclusive_time, pcs; init=zero(UInt64)) + for pc in pcs + for (t, mi) in pc.precompiles + m = mi.def + mod = isa(m, Method) ? m.module : m + list = get!(Vector{Tuple{UInt64,MethodInstance}}, pcdict, mod) + push!(list, (t, mi)) + end + end + pclist = [mod => (precompilable_time(list)/10^9, sort!(tosecs.(list); by=first)) for (mod, list) in pcdict] + sort!(pclist; by = pr -> pr.second[1]) + return t_grand_total/10^9, pclist +end + +parcel(t::InclusiveTiming) = parcel(precompilable_roots(t)) +parcel(t::Timing) = parcel(build_inclusive_times(t)) + +function get_reprs(tmi::Vector{Tuple{Float64,MethodInstance}}; tmin=0.001) + strs = OrderedSet{String}() + modgens = Dict{Module, Vector{Method}}() + tmp = String[] + twritten = 0.0 + for (t, mi) in reverse(tmi) + if t >= tmin + if add_repr!(tmp, modgens, mi; check_eval=false, time=t) + str = pop!(tmp) + if !any(rex -> occursin(rex, str), default_exclusions) + push!(strs, str) + twritten += t + end + end + end + end + return strs, twritten +end + +function write(io::IO, tmi::Vector{Tuple{Float64,MethodInstance}}; indent::AbstractString=" ", kwargs...) + strs, twritten = get_reprs(tmi; kwargs...) + for str in strs + println(io, indent, str) + end + return twritten, length(strs) +end + +function write(prefix::AbstractString, pc::Vector{Pair{Module,Tuple{Float64,Vector{Tuple{Float64,MethodInstance}}}}}; ioreport::IO=stdout, header::Bool=true, always::Bool=false, kwargs...) + if !isdir(prefix) + mkpath(prefix) + end + for (mod, ttmi) in pc + tmod, tmi = ttmi + v, twritten = get_reprs(tmi; kwargs...) + if isempty(v) + println(ioreport, "$mod: no precompile statements out of $tmod") + continue + end + open(joinpath(prefix, "precompile_$(mod).jl"), "w") do io + if header + if any(str->occursin("__lookup", str), v) + println(io, lookup_kwbody_str) + end + println(io, "function _precompile_()") + !always && println(io, " ccall(:jl_generating_output, Cint, ()) == 1 || return nothing") + end + for ln in v + println(io, " ", ln) + end + header && println(io, "end") + end + println(ioreport, "$mod: precompiled $twritten out of $tmod") + end +end + """ flamegraph(t::Core.Compiler.Timings.Timing; tmin_secs=0.0) flamegraph(t::SnoopCompile.InclusiveTiming; tmin_secs=0.0) diff --git a/src/write.jl b/src/write.jl index ff3e956cb..ee72d0101 100644 --- a/src/write.jl +++ b/src/write.jl @@ -1,5 +1,5 @@ # Write precompiles for userimg.jl -function write(io::IO, pc::Vector) +function write(io::IO, pc::Vector{<:AbstractString}) for ln in pc println(io, ln) end @@ -10,10 +10,11 @@ function write(filename::AbstractString, pc::Vector) if !isdir(path) mkpath(path) end + ret = nothing open(filename, "w") do io - write(io, pc) + ret = write(io, pc) end - nothing + return ret end """ diff --git a/test/snoopi_deep.jl b/test/snoopi_deep.jl index 52373e03e..3a4072ed7 100644 --- a/test/snoopi_deep.jl +++ b/test/snoopi_deep.jl @@ -1,6 +1,7 @@ using SnoopCompile using SnoopCompile.SnoopCompileCore using Test +using Random using AbstractTrees # For FlameGraphs tests @@ -21,7 +22,7 @@ using AbstractTrees # For FlameGraphs tests g(y::Integer) = h(Any[y]) end - timing = SnoopCompileCore.@snoopi_deep begin + timing = @snoopi_deep begin M.g(2) M.g(true) end @@ -66,7 +67,7 @@ end g(y::Integer) = h(Any[y]) end - timing = SnoopCompileCore.@snoopi_deep begin + timing = @snoopi_deep begin M.g(2) end times = flatten_times(timing) @@ -87,3 +88,60 @@ end @test length(collect(AbstractTrees.PreOrderDFS(fg2))) == (length(collect(AbstractTrees.PreOrderDFS(fg))) - 1) end end + +include("testmodules/SnoopBench.jl") +@testset "parcel" begin + a = SnoopBench.A() + tinf = @snoopi_deep SnoopBench.f1(a) + ttot, prs = SnoopCompile.parcel(tinf) + mod, (tmod, tmis) = only(prs) + @test mod === SnoopBench + t, mi = only(tmis) + @test ttot == tmod == t # since there is only one + @test mi.def.name === :f1 + + A = [a] + tinf = @snoopi_deep SnoopBench.mappushes(identity, A) + ttot, prs = SnoopCompile.parcel(tinf) + mod, (tmod, tmis) = only(prs) + @test mod === SnoopBench + @test ttot == tmod # since there is only one + @test length(tmis) == 2 + io = IOBuffer() + SnoopCompile.write(io, tmis; tmin=0.0) + str = String(take!(io)) + @test occursin(r"typeof\(mappushes\),Any,Vector\{A\}", str) + @test occursin(r"typeof\(mappushes!\),typeof\(identity\),Vector\{Any\},Vector\{A\}", str) + + list = Any[1, 1.0, Float16(1.0), a] + tinf = @snoopi_deep SnoopBench.mappushes(isequal(Int8(1)), list) + ttot, prs = SnoopCompile.parcel(tinf) + @test length(prs) == 2 + _, (tmodBase, tmis) = prs[findfirst(pr->pr.first === Base, prs)] + tw, nw = SnoopCompile.write(io, tmis; tmin=0.0) + @test 0.0 <= tw <= tmodBase && 0 <= nw <= length(tmis)-1 + str = String(take!(io)) + @test !occursin(r"Base.Fix2\{typeof\(isequal\).*SnoopBench.A\}", str) + @test length(split(chomp(str), '\n')) == nw + _, (tmodBench, tmis) = prs[findfirst(pr->pr.first === SnoopBench, prs)] + @test tmodBench + tmodBase ≈ ttot + tw, nw = SnoopCompile.write(io, tmis; tmin=0.0) + @test nw == 2 + str = String(take!(io)) + @test occursin(r"typeof\(mappushes\),Any,Vector\{Any\}", str) + @test occursin(r"typeof\(mappushes!\),Base.Fix2\{typeof\(isequal\).*\},Vector\{Any\},Vector\{Any\}", str) + + td = joinpath(tempdir(), randstring(8)) + SnoopCompile.write(td, prs; ioreport=io) + str = String(take!(io)) + @test occursin(r"Base: precompiled [\d\.]+ out of [\d\.]+", str) + @test occursin(r"SnoopBench: precompiled [\d\.]+ out of [\d\.]+", str) + file_base = joinpath(td, "precompile_Base.jl") + @test isfile(file_base) + @test occursin("ccall(:jl_generating_output", read(file_base, String)) + rm(td, recursive=true, force=true) + SnoopCompile.write(td, prs; ioreport=io, header=false) + str = String(take!(io)) # just to clear it in case we use it again + @test !occursin("ccall(:jl_generating_output", read(file_base, String)) + rm(td, recursive=true, force=true) +end diff --git a/test/testmodules/SnoopBench.jl b/test/testmodules/SnoopBench.jl new file mode 100644 index 000000000..5d4b6194d --- /dev/null +++ b/test/testmodules/SnoopBench.jl @@ -0,0 +1,25 @@ +module SnoopBench + +# Assignment of parcel to modules +struct A end +f3(::A) = 1 +f2(a::A) = f3(a) +f1(a::A) = f2(a) + +# Like map! except it uses push! +# With a single call site +mappushes!(f, dest, src) = (for item in src push!(dest, f(item)) end; return dest) +mappushes(@nospecialize(f), src) = mappushes!(f, [], src) +function mappushes3!(f, dest, src) + # A version with multiple call sites + item1 = src[1] + push!(dest, item1) + item2 = src[2] + push!(dest, item2) + item3 = src[3] + push!(dest, item3) + return dest +end +mappushes3(@nospecialize(f), src) = mappushes3!(f, [], src) + +end