diff --git a/Project.toml b/Project.toml index 565797d0..2d57291e 100644 --- a/Project.toml +++ b/Project.toml @@ -17,8 +17,20 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +[weakdeps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" + +[extensions] +ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"] + [compat] +BenchmarkTools = "1" ClimaComms = "0.4, 0.5, 0.6" +CUDA = "3, 4, 5" Colors = "0.12" DataStructures = "0.18" DiffEqBase = "6" @@ -27,7 +39,10 @@ KernelAbstractions = "0.7, 0.8, 0.9" Krylov = "0.8, 0.9" LinearAlgebra = "1" LinearOperators = "2" +OrderedCollections = "1" +PrettyTables = "2" NVTX = "0.3" SciMLBase = "1, 2" StaticArrays = "1" +StatsBase = "0.33, 0.34" julia = "1.8" diff --git a/ext/ClimaTimeSteppersBenchmarkToolsExt.jl b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl new file mode 100644 index 00000000..b2e5b2e2 --- /dev/null +++ b/ext/ClimaTimeSteppersBenchmarkToolsExt.jl @@ -0,0 +1,101 @@ +module ClimaTimeSteppersBenchmarkToolsExt + + +import StatsBase +import SciMLBase +import PrettyTables +import OrderedCollections +import LinearAlgebra as LA +import ClimaTimeSteppers as CTS +using CUDA +import ClimaComms + +import BenchmarkTools +Base.:*(t::BenchmarkTools.Trial, n::Int) = n * t +Base.:*(n::Int, t::BenchmarkTools.Trial) = + BenchmarkTools.Trial(t.params, t.times .* n, t.gctimes .* n, t.memory * n, t.allocs * n) + +include("benchmark_utils.jl") + +function get_n_calls_per_step(integrator::CTS.DistributedODEIntegrator) + (; alg) = integrator + (; newtons_method, name) = alg + (; max_iters) = newtons_method + n_calls_per_step(name, max_iters) +end + +# TODO: generalize +n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict( + "Wfact" => 3 * max_newton_iters, + "ldiv!" => 3 * max_newton_iters, + "T_imp!" => 3 * max_newton_iters, + "T_exp_T_lim!" => 4, + "lim!" => 4, + "dss!" => 4, + "post_explicit!" => 3, + "post_implicit!" => 4, + "step!" => 1, +) + + +""" + benchmark_step( + integrator::DistributedODEIntegrator, + device::ClimaComms.AbstractDevice; + with_cu_prof = :bfrofile, # [:bprofile, :profile] + trace = false + ) + +Benchmark a DistributedODEIntegrator +""" +function CTS.benchmark_step( + integrator::CTS.DistributedODEIntegrator, + device::ClimaComms.AbstractDevice; + with_cu_prof = :bprofile, + trace = false, +) + (; u, p, t, dt, sol, alg) = integrator + (; f) = sol.prob + if f isa CTS.ClimaODEFunction + + W = get_W(integrator) + X = similar(u) + trials₀ = OrderedCollections.OrderedDict() + +#! format: off + trials₀["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace); + trials₀["ldiv!"] = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace); + trials₀["T_imp!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace); + trials₀["T_exp_T_lim!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace); + trials₀["lim!"] = get_trial(f.lim!, (X, p, t, u), "lim!", device; with_cu_prof, trace); + trials₀["dss!"] = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace); + trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace); + trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace); + trials₀["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace); +#! format: on + + trials = OrderedCollections.OrderedDict() + + n_calls_per_step = get_n_calls_per_step(integrator) + for k in keys(trials₀) + isnothing(trials₀[k]) && continue + trials[k] = trials₀[k] * n_calls_per_step[k] + end + + table_summary = OrderedCollections.OrderedDict() + for k in keys(trials) + isnothing(trials[k]) && continue + table_summary[k] = get_summary(trials[k], trials["step!"]) + end + + tabulate_summary(table_summary; n_calls_per_step) + + return (; table_summary, trials) + else + @warn "`ClimaTimeSteppers.benchmark` is not yet supported for $f." + return (; table_summary = nothing, trials = nothing) + end +end + + +end diff --git a/ext/benchmark_utils.jl b/ext/benchmark_utils.jl new file mode 100644 index 00000000..dd8384ce --- /dev/null +++ b/ext/benchmark_utils.jl @@ -0,0 +1,90 @@ +##### +##### BenchmarkTools's trial utils +##### + +get_summary(trial, trial_step) = (; + # Using some BenchmarkTools internals :/ + mem = BenchmarkTools.prettymemory(trial.memory), + mem_val = trial.memory, + nalloc = trial.allocs, + t_min = BenchmarkTools.prettytime(minimum(trial.times)), + t_max = BenchmarkTools.prettytime(maximum(trial.times)), + t_mean = BenchmarkTools.prettytime(StatsBase.mean(trial.times)), + t_mean_val = StatsBase.mean(trial.times), + t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times)), + n_samples = length(trial), + percentage = minimum(trial.times) / minimum(trial_step.times) * 100, +) + +function tabulate_summary(summary; n_calls_per_step) + summary_keys = collect(keys(summary)) + mem = map(k -> summary[k].mem, summary_keys) + nalloc = map(k -> summary[k].nalloc, summary_keys) + t_mean = map(k -> summary[k].t_mean, summary_keys) + t_min = map(k -> summary[k].t_min, summary_keys) + t_max = map(k -> summary[k].t_max, summary_keys) + t_med = map(k -> summary[k].t_med, summary_keys) + n_samples = map(k -> summary[k].n_samples, summary_keys) + percentage = map(k -> summary[k].percentage, summary_keys) + + func_names = if isnothing(n_calls_per_step) + map(k -> string(k), collect(keys(summary))) + else + @info "(#)x entries have been multiplied by corresponding factors in order to compute percentages" + map(k -> string(k, " ($(n_calls_per_step[k])x)"), collect(keys(summary))) + end + table_data = hcat(func_names, mem, nalloc, t_min, t_max, t_mean, t_med, n_samples, percentage) + + header = ( + ["Function", "Memory", "allocs", "Time", "Time", "Time", "Time", "N-samples", "step! percentage"], + [" ", "estimate", "estimate", "min", "max", "mean", "median", "", ""], + ) + + PrettyTables.pretty_table( + table_data; + header, + crop = :none, + alignment = vcat(:l, repeat([:r], length(header[1]) - 1)), + ) +end + +get_trial(f::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false) = nothing +function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false) + f(args...) # compile first + b = if device isa ClimaComms.CUDADevice + BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...) + else + BenchmarkTools.@benchmarkable $f($(args)...) + end + sample_limit = 10 + println("--------------- Benchmarking/profiling $name...") + trial = BenchmarkTools.run(b, samples = sample_limit) + if device isa ClimaComms.CUDADevice + p = if with_cu_prof == :bprofile + CUDA.@bprofile trace = trace f(args...) + else + CUDA.@profile trace = trace f(args...) + end + println(p) + end + println() + return trial +end + +get_W(i::CTS.DistributedODEIntegrator) = i.cache.newtons_method_cache.j +get_W(i) = i.cache.W +f_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), i.u, i.p, i.t, i.dt) +f_args(i, f) = (similar(i.u), i.u, i.p, i.t) + +r_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), copy(i.u), i.u, i.p, i.t, i.dt) +r_args(i, f) = (similar(i.u), similar(i.u), i.u, i.p, i.t) + +implicit_args(i::CTS.DistributedODEIntegrator) = f_args(i, i.sol.prob.f.T_imp!) +implicit_args(i) = f_args(i, i.f.f1) +remaining_args(i::CTS.DistributedODEIntegrator) = r_args(i, i.sol.prob.f.T_exp_T_lim!) +remaining_args(i) = r_args(i, i.f.f2) +wfact_fun(i) = implicit_fun(i).Wfact +implicit_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_imp! +implicit_fun(i) = i.sol.prob.f.f1 +remaining_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_exp_T_lim! +remaining_fun(i) = i.sol.prob.f.f2 diff --git a/perf/Manifest.toml b/perf/Manifest.toml index cc60868a..12c8fa19 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.10.3" manifest_format = "2.0" -project_hash = "86292bcc5e784cb4f6a9709b4395280ec762f782" +project_hash = "192050c2153dcc826c1eca5d5b7a774796048dec" [[deps.ADTypes]] git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224" @@ -248,7 +248,11 @@ weakdeps = ["Krylov"] deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"] path = ".." uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" -version = "0.7.20" +version = "0.7.21" +weakdeps = ["BenchmarkTools", "CUDA", "OrderedCollections", "PrettyTables", "StatsBase"] + + [deps.ClimaTimeSteppers.extensions] + ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"] [[deps.CloseOpenIntervals]] deps = ["Static", "StaticArrayInterface"] @@ -304,7 +308,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -1461,6 +1465,12 @@ git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" version = "1.7.0" +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.3" + [[deps.StrideArraysCore]] deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"] git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682" diff --git a/perf/Project.toml b/perf/Project.toml index c50c64f7..43b43955 100644 --- a/perf/Project.toml +++ b/perf/Project.toml @@ -1,6 +1,7 @@ [deps] ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884" ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79" @@ -15,12 +16,15 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PProf = "e4faabce-9ead-11e9-39d9-4379958e3056" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" diff --git a/perf/benchmark.jl b/perf/benchmark.jl index fc64f49c..1d8200cb 100644 --- a/perf/benchmark.jl +++ b/perf/benchmark.jl @@ -1,21 +1,39 @@ +using ArgParse, JET, Test, DiffEqBase, ClimaTimeSteppers +using ClimaComms +using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step import ClimaTimeSteppers as CTS -import DiffEqBase -using BenchmarkTools, DiffEqBase - -include(joinpath(pkgdir(CTS), "test", "problems.jl")) - -function main() - algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2)) - dt = 0.01 - for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe()) - integrator = DiffEqBase.init(problem, algorithm; dt) - - cache = CTS.init_cache(problem, algorithm) - - CTS.step_u!(integrator, cache) - - trial = @benchmark CTS.step_u!($integrator, $cache) - show(stdout, MIME("text/plain"), trial) +function parse_commandline() + s = ArgParse.ArgParseSettings() + ArgParse.@add_arg_table s begin + "--problem" + help = "Problem type [`ode_fun`, `fe`]" + arg_type = String + default = "diffusion2d" end + parsed_args = ArgParse.parse_args(ARGS, s) + return (s, parsed_args) +end +(s, parsed_args) = parse_commandline() +cts = joinpath(dirname(@__DIR__)); +include(joinpath(cts, "test", "problems.jl")) +config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob) +function config_integrators(problem) + algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2)) + dt = 0.01 + integrator = DiffEqBase.init(problem, algorithm; dt) + integrator.cache = CTS.init_cache(problem, algorithm) + return (; integrator) end -main() +prob = if parsed_args["problem"] == "diffusion2d" + climacore_2Dheat_test_cts(Float64) +elseif parsed_args["problem"] == "ode_fun" + split_linear_prob_wfact_split() +elseif parsed_args["problem"] == "fe" + split_linear_prob_wfact_split_fe() +else + error("Bad option") +end +(; integrator) = config_integrators(prob) + +device = ClimaComms.device() +CTS.benchmark_step(integrator, device) diff --git a/src/ClimaTimeSteppers.jl b/src/ClimaTimeSteppers.jl index 78a14a90..1fba8092 100644 --- a/src/ClimaTimeSteppers.jl +++ b/src/ClimaTimeSteppers.jl @@ -128,4 +128,8 @@ include("solvers/rosenbrock.jl") include("Callbacks.jl") + +benchmark_step(integrator, device) = + @warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension" + end diff --git a/test/aqua.jl b/test/aqua.jl index 4453e1f1..d8603c92 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -18,7 +18,7 @@ using Aqua end @testset "Aqua tests (remaining)" begin - Aqua.test_all(ClimaTimeSteppers; ambiguities=false) + Aqua.test_all(ClimaTimeSteppers; ambiguities = false) end nothing