diff --git a/Project.toml b/Project.toml
index 565797d06..2d57291e6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,8 +17,20 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 
+[weakdeps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
+
+[extensions]
+ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]
+
 [compat]
+BenchmarkTools = "1"
 ClimaComms = "0.4, 0.5, 0.6"
+CUDA = "3, 4, 5"
 Colors = "0.12"
 DataStructures = "0.18"
 DiffEqBase = "6"
@@ -27,7 +39,10 @@ KernelAbstractions = "0.7, 0.8, 0.9"
 Krylov = "0.8, 0.9"
 LinearAlgebra = "1"
 LinearOperators = "2"
+OrderedCollections = "1"
+PrettyTables = "2"
 NVTX = "0.3"
 SciMLBase = "1, 2"
 StaticArrays = "1"
+StatsBase = "0.33, 0.34"
 julia = "1.8"
diff --git a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl
new file mode 100644
index 000000000..5f1713c83
--- /dev/null
+++ b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl
@@ -0,0 +1,119 @@
+module ClimaTimeSteppersBenchmarkToolsExt
+
+import StatsBase
+import SciMLBase
+import PrettyTables
+import OrderedCollections
+import LinearAlgebra as LA
+import ClimaTimeSteppers as CTS
+using CUDA
+import ClimaComms
+
+import BenchmarkTools
+
+# NOTE(review): these methods are type piracy (this package owns neither `*` nor
+# `BenchmarkTools.Trial`); they are kept so that existing `trial * n` code keeps working.
+Base.:*(t::BenchmarkTools.Trial, n::Int) = n * t
+Base.:*(n::Int, t::BenchmarkTools.Trial) =
+    BenchmarkTools.Trial(t.params, t.times .* n, t.gctimes .* n, t.memory * n, t.allocs * n)
+
+include("benchmark_utils.jl")
+
+"""
+    n_calls_per_step(name, max_newton_iters)
+
+Return a `Dict` mapping each benchmarked function name to the factor used to scale
+that function's trial to one `step!` of the algorithm `name`, or `nothing` when no
+factors are defined for that algorithm.
+"""
+n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
+    "Wfact" => 3 * max_newton_iters,
+    "ldiv!" => 3 * max_newton_iters,
+    "T_imp!" => 3 * max_newton_iters,
+    "T_exp_T_lim!" => 4,
+    "lim!" => 4,
+    "dss!" => 4,
+    "post_explicit!" => 3,
+    "post_implicit!" => 4,
+    "step!" => 1,
+)
+n_calls_per_step(name, max_newton_iters) = nothing
+
+# Scale a trial by its per-step call count; a skipped benchmark (`nothing`) stays `nothing`.
+scale_trial(::Nothing, n) = nothing
+scale_trial(trial, n) = trial * n
+
+"""
+    benchmark_step(
+        integrator::DistributedODEIntegrator,
+        device::ClimaComms.AbstractDevice;
+        with_cu_prof = :bprofile, # [:bprofile, :profile]
+        trace = false
+    )
+
+Benchmark a DistributedODEIntegrator.
+
+Each component function of the integrator's `ClimaODEFunction`, and `step!` itself, is
+benchmarked separately and the results are printed as a table. `with_cu_prof` and
+`trace` are forwarded to the CUDA profiling run, which only happens on a
+`ClimaComms.CUDADevice`.
+
+Returns `(; table_summary, trials)`; both are `nothing` when the problem's function is
+not a `ClimaODEFunction`.
+"""
+function CTS.benchmark_step(
+    integrator::CTS.DistributedODEIntegrator,
+    device::ClimaComms.AbstractDevice;
+    with_cu_prof = :bprofile,
+    trace = false,
+)
+    (; u, p, t, dt, sol, alg) = integrator
+    (; f) = sol.prob
+    if !(f isa CTS.ClimaODEFunction)
+        @warn "`ClimaTimeSteppers.benchmark_step` is not yet supported for $f."
+        return (; table_summary = nothing, trials = nothing)
+    end
+
+    W = get_W(integrator)
+    X = similar(u)
+    trials₀ = OrderedCollections.OrderedDict()
+
+#! format: off
+    trials₀["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace);
+    trials₀["ldiv!"] = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace);
+    trials₀["T_imp!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace);
+    trials₀["T_exp_T_lim!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace);
+    trials₀["lim!"] = get_trial(f.lim!, (X, p, t, u), "lim!", device; with_cu_prof, trace);
+    trials₀["dss!"] = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace);
+    trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace);
+    trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace);
+    trials₀["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace);
+#! format: on
+
+    (; newtons_method, name) = alg
+    (; max_iters) = newtons_method
+
+    # Scaling factors are only defined for some algorithms; without them the trials
+    # are reported unscaled and the percentage column is not meaningful.
+    n_calls = n_calls_per_step(name, max_iters)
+    keep_percentage = !isnothing(n_calls)
+
+    trials = OrderedCollections.OrderedDict()
+    for (k, trial) in trials₀
+        trials[k] = keep_percentage ? scale_trial(trial, n_calls[k]) : trial
+    end
+
+    table_summary = OrderedCollections.OrderedDict()
+    for (k, trial) in trials
+        isnothing(trial) && continue # benchmark was skipped (function is `nothing`)
+        table_summary[k] = get_summary(trial, trials["step!"]; keep_percentage)
+    end
+    keep_percentage ||
+        @warn "The percentage column was computed incorrectly, please open an issue in ClimaTimeSteppers with a reproducer."
+
+    tabulate_summary(table_summary; n_calls)
+
+    return (; table_summary, trials)
+end
+
+end
diff --git a/ext/benchmark_utils.jl b/ext/benchmark_utils.jl
new file mode 100644
index 000000000..f69bdd0dd
--- /dev/null
+++ b/ext/benchmark_utils.jl
@@ -0,0 +1,101 @@
+#####
+##### BenchmarkTools's trial utils
+#####
+
+# Summarize `trial` as a NamedTuple of formatted and raw statistics. `percentage` is the
+# minimum time of `trial` relative to that of `trial_step`, or -1 if `keep_percentage` is false.
+get_summary(trial, trial_step; keep_percentage) = (;
+    # Using some BenchmarkTools internals :/
+    mem = BenchmarkTools.prettymemory(trial.memory),
+    mem_val = trial.memory,
+    nalloc = trial.allocs,
+    t_min = BenchmarkTools.prettytime(minimum(trial.times)),
+    t_max = BenchmarkTools.prettytime(maximum(trial.times)),
+    t_mean = BenchmarkTools.prettytime(StatsBase.mean(trial.times)),
+    t_mean_val = StatsBase.mean(trial.times),
+    t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times)),
+    n_samples = length(trial),
+    percentage = keep_percentage ? minimum(trial.times) / minimum(trial_step.times) * 100 : -1,
+)
+
+# Print `summary` (function name => `get_summary` result) as a table. When `n_calls` is not
+# `nothing`, each function name is annotated with its entry in `n_calls`.
+function tabulate_summary(summary; n_calls)
+    summary_keys = collect(keys(summary))
+    mem = map(k -> summary[k].mem, summary_keys)
+    nalloc = map(k -> summary[k].nalloc, summary_keys)
+    t_mean = map(k -> summary[k].t_mean, summary_keys)
+    t_min = map(k -> summary[k].t_min, summary_keys)
+    t_max = map(k -> summary[k].t_max, summary_keys)
+    t_med = map(k -> summary[k].t_med, summary_keys)
+    n_samples = map(k -> summary[k].n_samples, summary_keys)
+    percentage = map(k -> summary[k].percentage, summary_keys)
+
+    func_names = if isnothing(n_calls)
+        map(string, summary_keys)
+    else
+        @info "(#)x entries have been multiplied by corresponding factors in order to compute percentages"
+        map(k -> string(k, " ($(n_calls[k])x)"), summary_keys)
+    end
+    table_data = hcat(func_names, mem, nalloc, t_min, t_max, t_mean, t_med, n_samples, percentage)
+
+    header = (
+        ["Function", "Memory", "allocs", "Time", "Time", "Time", "Time", "N-samples", "step! percentage"],
+        [" ", "estimate", "estimate", "min", "max", "mean", "median", "", ""],
+    )
+
+    PrettyTables.pretty_table(
+        table_data;
+        header,
+        crop = :none,
+        alignment = vcat(:l, repeat([:r], length(header[1]) - 1)),
+    )
+end
+
+# Benchmark `f(args...)` (and, on a CUDA device, profile it). Returns the
+# `BenchmarkTools.Trial`, or `nothing` when there is no function to benchmark.
+get_trial(f::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false) = nothing
+function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false)
+    on_gpu = device isa ClimaComms.CUDADevice
+    # Fail early instead of hitting an undefined `p` after the benchmark has run.
+    if on_gpu && !(with_cu_prof in (:bprofile, :profile))
+        throw(ArgumentError("with_cu_prof must be :bprofile or :profile, got $(repr(with_cu_prof))"))
+    end
+    f(args...) # compile first
+    b = if on_gpu
+        BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...)
+    else
+        BenchmarkTools.@benchmarkable $f($(args)...)
+    end
+    sample_limit = 10
+    println("--------------- Benchmarking/profiling $name...")
+    trial = BenchmarkTools.run(b, samples = sample_limit)
+    if on_gpu
+        if with_cu_prof == :bprofile
+            p = CUDA.@bprofile trace = trace f(args...)
+        elseif with_cu_prof == :profile
+            p = CUDA.@profile trace = trace f(args...)
+        end
+        println(p)
+    end
+    println()
+    return trial
+end
+
+get_W(i::CTS.DistributedODEIntegrator) = i.cache.newtons_method_cache.j
+get_W(i) = i.cache.W
+f_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), i.u, i.p, i.t, i.dt)
+f_args(i, f) = (similar(i.u), i.u, i.p, i.t)
+
+r_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), copy(i.u), i.u, i.p, i.t, i.dt)
+r_args(i, f) = (similar(i.u), similar(i.u), i.u, i.p, i.t)
+
+implicit_args(i::CTS.DistributedODEIntegrator) = f_args(i, i.sol.prob.f.T_imp!)
+implicit_args(i) = f_args(i, i.f.f1)
+remaining_args(i::CTS.DistributedODEIntegrator) = r_args(i, i.sol.prob.f.T_exp_T_lim!)
+remaining_args(i) = r_args(i, i.f.f2)
+wfact_fun(i) = implicit_fun(i).Wfact
+implicit_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_imp!
+implicit_fun(i) = i.sol.prob.f.f1
+remaining_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_exp_T_lim!
+remaining_fun(i) = i.sol.prob.f.f2
diff --git a/perf/benchmark.jl b/perf/benchmark.jl
index fc64f49ca..6bf4f310f 100644
--- a/perf/benchmark.jl
+++ b/perf/benchmark.jl
@@ -2,20 +2,19 @@ import ClimaTimeSteppers as CTS
 import DiffEqBase
 using BenchmarkTools, DiffEqBase
 
+import ClimaComms
+using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step
+
 include(joinpath(pkgdir(CTS), "test", "problems.jl"))
 
 function main()
     algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2))
     dt = 0.01
+    device = ClimaComms.device()
     for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe())
         integrator = DiffEqBase.init(problem, algorithm; dt)
 
-        cache = CTS.init_cache(problem, algorithm)
-
-        CTS.step_u!(integrator, cache)
-
-        trial = @benchmark CTS.step_u!($integrator, $cache)
-        show(stdout, MIME("text/plain"), trial)
+        CTS.benchmark_step(integrator, device)
     end
 end
 main()
diff --git a/src/ClimaTimeSteppers.jl b/src/ClimaTimeSteppers.jl
index 78a14a90c..1fba80920 100644
--- a/src/ClimaTimeSteppers.jl
+++ b/src/ClimaTimeSteppers.jl
@@ -128,4 +128,19 @@ include("solvers/rosenbrock.jl")
 
 include("Callbacks.jl")
 
+
+"""
+    benchmark_step(integrator, device; kwargs...)
+
+Benchmark one step of `integrator` on `device`.
+
+This fallback only emits a warning and returns `(; table_summary = nothing, trials = nothing)`;
+the implementation lives in the `ClimaTimeSteppersBenchmarkToolsExt` extension. Keyword
+arguments are accepted and ignored so that the warning is shown instead of a `MethodError`.
+"""
+function benchmark_step(integrator, device; kwargs...)
+    @warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"
+    return (; table_summary = nothing, trials = nothing)
+end
+
 end
diff --git a/test/aqua.jl b/test/aqua.jl
index 4453e1f19..d8603c929 100644
--- a/test/aqua.jl
+++ b/test/aqua.jl
@@ -18,7 +18,7 @@ using Aqua
 end
 
 @testset "Aqua tests (remaining)" begin
-    Aqua.test_all(ClimaTimeSteppers; ambiguities=false)
+    Aqua.test_all(ClimaTimeSteppers; ambiguities = false)
 end
 
 nothing