Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed May 3, 2024
1 parent 5105226 commit de5d88b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 40 deletions.
43 changes: 12 additions & 31 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ Base.:*(n::Int, t::BenchmarkTools.Trial) =

include("benchmark_utils.jl")

function get_n_calls_per_step(integrator::CTS.DistributedODEIntegrator)
(; alg) = integrator
(; newtons_method, name) = alg
(; max_iters) = newtons_method
n_calls_per_step(name, max_iters)
end

# TODO: generalize
n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
"Wfact" => 3 * max_newton_iters,
"ldiv!" => 3 * max_newton_iters,
Expand Down Expand Up @@ -50,7 +58,6 @@ function CTS.benchmark_step(
(; f) = sol.prob
if f isa CTS.ClimaODEFunction

(; sol, u, p, dt, t) = integrator
W = get_W(integrator)
X = similar(u)
trials₀ = OrderedCollections.OrderedDict()
Expand All @@ -68,46 +75,20 @@ function CTS.benchmark_step(
#! format: on

trials = OrderedCollections.OrderedDict()
local n_calls
(; alg) = integrator
(; newtons_method, name) = alg
(; max_iters) = newtons_method


keep_percentage = true
n_calls_per_step = get_n_calls_per_step(integrator)
for k in keys(trials₀)
isnothing(trials₀[k]) && continue
trials[k] = trials₀[k] * n_calls_per_step(name, max_iters)[k]
trials[k] = trials₀[k] * n_calls_per_step[k]
end
n_calls = Dict(map(collect(keys(trials₀))) do k
k => n_calls_per_step(name, max_iters)[k]
end)

# keep_percentage = try
# for k in keys(trials₀)
# trials[k] = trials₀[k] * n_calls_per_step(name, max_iters)[k]
# end
# n_calls = map(collect(keys(trials₀))) do k
# n_calls_per_step(name, max_iters)[k]
# end
# true
# catch
# for k in keys(trials₀)
# trials[k] = trials₀[k]
# end
# n_calls = nothing
# false
# end

table_summary = OrderedCollections.OrderedDict()
for k in keys(trials)
isnothing(trials[k]) && continue
table_summary[k] = get_summary(trials[k], trials["step!"]; keep_percentage)
table_summary[k] = get_summary(trials[k], trials["step!"])
end
keep_percentage ||
@warn "The percentage column was computed incorrectly, please open an issue in ClimaTimeSteppers with a reproducer."

tabulate_summary(table_summary; n_calls)
tabulate_summary(table_summary; n_calls_per_step)

return (; table_summary, trials)
else
Expand Down
18 changes: 9 additions & 9 deletions ext/benchmark_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
##### BenchmarkTools's trial utils
#####

get_summary(trial, trial_step; keep_percentage) = (;
get_summary(trial, trial_step) = (;
# Using some BenchmarkTools internals :/
mem = BenchmarkTools.prettymemory(trial.memory),
mem_val = trial.memory,
Expand All @@ -13,10 +13,10 @@ get_summary(trial, trial_step; keep_percentage) = (;
t_mean_val = StatsBase.mean(trial.times),
t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times)),
n_samples = length(trial),
percentage = keep_percentage ? minimum(trial.times) / minimum(trial_step.times) * 100 : -1,
percentage = minimum(trial.times) / minimum(trial_step.times) * 100,
)

function tabulate_summary(summary; n_calls)
function tabulate_summary(summary; n_calls_per_step)
summary_keys = collect(keys(summary))
mem = map(k -> summary[k].mem, summary_keys)
nalloc = map(k -> summary[k].nalloc, summary_keys)
Expand All @@ -27,11 +27,11 @@ function tabulate_summary(summary; n_calls)
n_samples = map(k -> summary[k].n_samples, summary_keys)
percentage = map(k -> summary[k].percentage, summary_keys)

func_names = if isnothing(n_calls)
func_names = if isnothing(n_calls_per_step)
map(k -> string(k), collect(keys(summary)))
else
@info "(#)x entries have been multiplied by corresponding factors in order to compute percentages"
map(k -> string(k, " ($(n_calls[k])x)"), collect(keys(summary)))
map(k -> string(k, " ($(n_calls_per_step[k])x)"), collect(keys(summary)))
end
table_data = hcat(func_names, mem, nalloc, t_min, t_max, t_mean, t_med, n_samples, percentage)

Expand Down Expand Up @@ -60,10 +60,10 @@ function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = fals
println("--------------- Benchmarking/profiling $name...")
trial = BenchmarkTools.run(b, samples = sample_limit)
if device isa ClimaComms.CUDADevice
if with_cu_prof == :bprofile
p = CUDA.@bprofile trace = trace f(args...)
elseif with_cu_prof == :profile
p = CUDA.@profile trace = trace f(args...)
p = if with_cu_prof == :bprofile
CUDA.@bprofile trace = trace f(args...)
else
CUDA.@profile trace = trace f(args...)
end
println(p)
end
Expand Down

0 comments on commit de5d88b

Please sign in to comment.