-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #277 from CliMA/ck/benchmark
Add benchmark utility extension
- Loading branch information
Showing
8 changed files
with
265 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
module ClimaTimeSteppersBenchmarkToolsExt | ||
|
||
|
||
import StatsBase | ||
import SciMLBase | ||
import PrettyTables | ||
import OrderedCollections | ||
import LinearAlgebra as LA | ||
import ClimaTimeSteppers as CTS | ||
using CUDA | ||
import ClimaComms | ||
|
||
import BenchmarkTools | ||
# Scale a benchmark trial by an integer call count `n`: every recorded time and
# GC time is multiplied by `n`, as are the memory and allocation counts, while
# the trial's `params` are carried over unchanged.
# NOTE(review): neither `*` nor `BenchmarkTools.Trial` is defined in this
# module, so these methods extend foreign types — confirm this is intended.
function Base.:*(n::Int, t::BenchmarkTools.Trial)
    scaled_times = t.times .* n
    scaled_gctimes = t.gctimes .* n
    return BenchmarkTools.Trial(t.params, scaled_times, scaled_gctimes, t.memory * n, t.allocs * n)
end
# Commuted form, so that `trial * n` and `n * trial` give the same result.
Base.:*(t::BenchmarkTools.Trial, n::Int) = n * t
|
||
include("benchmark_utils.jl") | ||
|
||
# Look up how many times each benchmarked function is called per step for the
# integrator's algorithm, using the algorithm's `name` and the `max_iters` of
# its `newtons_method`.
function get_n_calls_per_step(integrator::CTS.DistributedODEIntegrator)
    alg = integrator.alg
    max_newton_iters = alg.newtons_method.max_iters
    return n_calls_per_step(alg.name, max_newton_iters)
end
|
||
# TODO: generalize
# Hard-coded per-step call counts for the ARS343 tableau, keyed by the same
# labels used for the benchmark trials. The three implicit-solve entries scale
# with the maximum number of Newton iterations; the remaining counts are fixed.
function n_calls_per_step(::CTS.ARS343, max_newton_iters)
    n_implicit = 3 * max_newton_iters
    return Dict(
        "Wfact" => n_implicit,
        "ldiv!" => n_implicit,
        "T_imp!" => n_implicit,
        "T_exp_T_lim!" => 4,
        "lim!" => 4,
        "dss!" => 4,
        "post_explicit!" => 3,
        "post_implicit!" => 4,
        "step!" => 1,
    )
end
|
||
|
||
""" | ||
benchmark_step( | ||
integrator::DistributedODEIntegrator, | ||
device::ClimaComms.AbstractDevice; | ||
with_cu_prof = :bprofile, # [:bprofile, :profile] | ||
trace = false | ||
) | ||
Benchmark a DistributedODEIntegrator | ||
""" | ||
function CTS.benchmark_step(
    integrator::CTS.DistributedODEIntegrator,
    device::ClimaComms.AbstractDevice;
    with_cu_prof = :bprofile,
    trace = false,
)
    # Keywords are forwarded unchanged to `get_trial`, which only uses them
    # when `device` is a `ClimaComms.CUDADevice`.
    # Returns `(; table_summary, trials)`; both are `nothing` when the
    # problem's function is not a `CTS.ClimaODEFunction`.
    (; u, p, t, dt, sol) = integrator
    (; f) = sol.prob
    if !(f isa CTS.ClimaODEFunction)
        @warn "`ClimaTimeSteppers.benchmark_step` is not yet supported for $f."
        return (; table_summary = nothing, trials = nothing)
    end

    W = get_W(integrator)
    X = similar(u)

    # Raw single-call trials. Entries whose function is `nothing` come back as
    # `nothing` from `get_trial`. Insertion order is the row order of the table.
    trials₀ = OrderedCollections.OrderedDict()

    #! format: off
    trials₀["Wfact"]          = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace);
    trials₀["ldiv!"]          = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace);
    trials₀["T_imp!"]         = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace);
    trials₀["T_exp_T_lim!"]   = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace);
    trials₀["lim!"]           = get_trial(f.lim!, (X, p, t, u), "lim!", device; with_cu_prof, trace);
    trials₀["dss!"]           = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace);
    trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace);
    trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace);
    trials₀["step!"]          = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace);
    #! format: on

    # Scale each available trial by its per-step call count so that entries
    # can be compared against a single `step!`.
    n_calls = get_n_calls_per_step(integrator)
    trials = OrderedCollections.OrderedDict()
    for (name, trial) in trials₀
        isnothing(trial) && continue
        trials[name] = trial * n_calls[name]
    end

    # `trials` holds no `nothing` entries at this point, so every entry is
    # summarized relative to the `step!` trial.
    step_trial = trials["step!"]
    table_summary = OrderedCollections.OrderedDict()
    for (name, trial) in trials
        table_summary[name] = get_summary(trial, step_trial)
    end

    tabulate_summary(table_summary; n_calls_per_step = n_calls)

    return (; table_summary, trials)
end
|
||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
##### | ||
##### BenchmarkTools's trial utils | ||
##### | ||
|
||
# Condense a trial into a named tuple of formatted and raw statistics.
# `percentage` is the ratio of this trial's minimum time to the minimum time
# of `trial_step`, expressed in percent.
function get_summary(trial, trial_step)
    times = trial.times
    fastest = minimum(times)
    average = StatsBase.mean(times)
    # Using some BenchmarkTools internals :/
    return (;
        mem = BenchmarkTools.prettymemory(trial.memory),
        mem_val = trial.memory,
        nalloc = trial.allocs,
        t_min = BenchmarkTools.prettytime(fastest),
        t_max = BenchmarkTools.prettytime(maximum(times)),
        t_mean = BenchmarkTools.prettytime(average),
        t_mean_val = average,
        t_med = BenchmarkTools.prettytime(StatsBase.median(times)),
        n_samples = length(trial),
        percentage = fastest / minimum(trial_step.times) * 100,
    )
end
|
||
# Print `summary` (a mapping from function label to the named tuple produced
# by `get_summary`) as a table with one row per entry.
#
# Keywords:
# - `n_calls_per_step`: optional mapping from label to call count. When given,
#   each label is suffixed with " (<count>x)" and an explanatory `@info` is
#   emitted; when `nothing` (the default), labels are printed as-is.
#
# Returns whatever `PrettyTables.pretty_table` returns.
function tabulate_summary(summary; n_calls_per_step = nothing)
    summary_keys = collect(keys(summary))
    # Extract one field of every summary entry, in key order.
    column(field) = map(k -> getproperty(summary[k], field), summary_keys)

    mem = column(:mem)
    nalloc = column(:nalloc)
    t_mean = column(:t_mean)
    t_min = column(:t_min)
    t_max = column(:t_max)
    t_med = column(:t_med)
    n_samples = column(:n_samples)
    percentage = column(:percentage)

    func_names = if isnothing(n_calls_per_step)
        map(string, summary_keys)
    else
        @info "(#)x entries have been multiplied by corresponding factors in order to compute percentages"
        map(k -> string(k, " ($(n_calls_per_step[k])x)"), summary_keys)
    end
    table_data = hcat(func_names, mem, nalloc, t_min, t_max, t_mean, t_med, n_samples, percentage)

    # Two-row header: column titles, then sub-titles.
    header = (
        ["Function", "Memory", "allocs", "Time", "Time", "Time", "Time", "N-samples", "step! percentage"],
        [" ", "estimate", "estimate", "min", "max", "mean", "median", "", ""],
    )

    # First column left-aligned, all remaining columns right-aligned.
    return PrettyTables.pretty_table(
        table_data;
        header,
        crop = :none,
        alignment = vcat(:l, repeat([:r], length(header[1]) - 1)),
    )
end
|
||
# Nothing to benchmark when the function is absent; callers skip `nothing`.
get_trial(f::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false, samples = 10) = nothing

# Benchmark `f(args...)` and return the resulting `BenchmarkTools` trial.
#
# Arguments:
# - `f`, `args`: the callable and the tuple of arguments it is called with.
# - `name`: label printed in the progress banner.
# - `device`: when it is a `ClimaComms.CUDADevice`, each benchmarked call is
#   wrapped in `CUDA.@sync` and a CUDA profile of one extra call is printed.
#
# Keywords:
# - `with_cu_prof`: `:bprofile` selects `CUDA.@bprofile`; any other value
#   selects `CUDA.@profile`. Only used on a CUDA device.
# - `trace`: forwarded as the `trace` option of the CUDA profiling macro.
# - `samples`: sample limit passed to `BenchmarkTools.run` (default 10).
function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false, samples = 10)
    f(args...) # compile first
    on_gpu = device isa ClimaComms.CUDADevice
    b = if on_gpu
        BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...)
    else
        BenchmarkTools.@benchmarkable $f($(args)...)
    end
    println("--------------- Benchmarking/profiling $name...")
    trial = BenchmarkTools.run(b, samples = samples)
    if on_gpu
        p = if with_cu_prof == :bprofile
            CUDA.@bprofile trace = trace f(args...)
        else
            CUDA.@profile trace = trace f(args...)
        end
        println(p)
    end
    println()
    return trial
end
|
||
# Accessors used to assemble benchmark inputs from an integrator. Each helper
# has a method for `CTS.DistributedODEIntegrator` and a generic fallback.

# The matrix-like object handed to `Wfact` and `ldiv!`.
get_W(integrator::CTS.DistributedODEIntegrator) = integrator.cache.newtons_method_cache.j
get_W(integrator) = integrator.cache.W

# Argument tuple for a tendency function with a single output buffer.
# `ForwardEulerODEFunction`s get a copy of `u` as the buffer plus a trailing `dt`.
f_args(integrator, ::CTS.ForwardEulerODEFunction) =
    (copy(integrator.u), integrator.u, integrator.p, integrator.t, integrator.dt)
f_args(integrator, ::Any) = (similar(integrator.u), integrator.u, integrator.p, integrator.t)

# Argument tuple for a tendency function with two output buffers.
r_args(integrator, ::CTS.ForwardEulerODEFunction) =
    (copy(integrator.u), copy(integrator.u), integrator.u, integrator.p, integrator.t, integrator.dt)
r_args(integrator, ::Any) =
    (similar(integrator.u), similar(integrator.u), integrator.u, integrator.p, integrator.t)

# Implicit / remaining tendency functions.
implicit_fun(integrator::CTS.DistributedODEIntegrator) = integrator.sol.prob.f.T_imp!
implicit_fun(integrator) = integrator.sol.prob.f.f1
remaining_fun(integrator::CTS.DistributedODEIntegrator) = integrator.sol.prob.f.T_exp_T_lim!
remaining_fun(integrator) = integrator.sol.prob.f.f2
wfact_fun(integrator) = implicit_fun(integrator).Wfact

# Arguments matching the functions above.
# NOTE(review): the generic fallbacks read `integrator.f`, whereas the `*_fun`
# fallbacks read `integrator.sol.prob.f` — confirm both refer to the same object.
implicit_args(integrator::CTS.DistributedODEIntegrator) = f_args(integrator, implicit_fun(integrator))
implicit_args(integrator) = f_args(integrator, integrator.f.f1)
remaining_args(integrator::CTS.DistributedODEIntegrator) = r_args(integrator, remaining_fun(integrator))
remaining_args(integrator) = r_args(integrator, integrator.f.f2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,39 @@ | ||
using ArgParse, JET, Test, DiffEqBase, ClimaTimeSteppers | ||
using ClimaComms | ||
using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step | ||
import ClimaTimeSteppers as CTS | ||
import DiffEqBase | ||
using BenchmarkTools, DiffEqBase | ||
|
||
include(joinpath(pkgdir(CTS), "test", "problems.jl")) | ||
|
||
function main() | ||
algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2)) | ||
dt = 0.01 | ||
for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe()) | ||
integrator = DiffEqBase.init(problem, algorithm; dt) | ||
|
||
cache = CTS.init_cache(problem, algorithm) | ||
|
||
CTS.step_u!(integrator, cache) | ||
|
||
trial = @benchmark CTS.step_u!($integrator, $cache) | ||
show(stdout, MIME("text/plain"), trial) | ||
# Parse the script's command-line arguments from `ARGS`.
#
# Recognized option:
# - `--problem`: name of the problem to benchmark (default "diffusion2d").
#
# Returns `(s, parsed_args)`: the `ArgParseSettings` object and the parsed
# arguments.
function parse_commandline()
    s = ArgParse.ArgParseSettings()
    ArgParse.@add_arg_table s begin
        "--problem"
        # The help text lists every value accepted by the problem selection
        # further down in this script, including the default.
        help = "Problem type [`diffusion2d`, `ode_fun`, `fe`]"
        arg_type = String
        default = "diffusion2d"
    end
    parsed_args = ArgParse.parse_args(ARGS, s)
    return (s, parsed_args)
end
# Read the command line, then load the shared test problems from the `test`
# directory located one level above this script's directory.
s, parsed_args = parse_commandline()
cts = dirname(@__DIR__)
include(joinpath(cts, "test", "problems.jl"))
# Build an integrator for `problem` using ARS343 with Newton's method
# (`max_iters = 2`) and `dt = 0.01`, then replace its cache with a freshly
# initialized one. Returns a named tuple `(; integrator)`.
function config_integrators(problem)
    dt = 0.01
    newtons_method = NewtonsMethod(; max_iters = 2)
    algorithm = CTS.IMEXAlgorithm(ARS343(), newtons_method)
    integrator = DiffEqBase.init(problem, algorithm; dt = dt)
    integrator.cache = CTS.init_cache(problem, algorithm)
    return (; integrator = integrator)
end
# Test cases are configured through the problem they wrap.
config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob)
main() | ||
# Pick the problem named by `--problem`; unknown names raise an error.
prob = let choice = parsed_args["problem"]
    if choice == "diffusion2d"
        climacore_2Dheat_test_cts(Float64)
    elseif choice == "ode_fun"
        split_linear_prob_wfact_split()
    elseif choice == "fe"
        split_linear_prob_wfact_split_fe()
    else
        error("Bad option")
    end
end
integrator = config_integrators(prob).integrator

# Benchmark one step on the device reported by ClimaComms.
device = ClimaComms.device()
CTS.benchmark_step(integrator, device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters