Skip to content

Commit

Permalink
Merge pull request #277 from CliMA/ck/benchmark
Browse files Browse the repository at this point in the history
Add benchmark utility extension
  • Loading branch information
charleskawczynski authored May 3, 2024
2 parents 783370a + de5d88b commit 83cf6b6
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 23 deletions.
15 changes: 15 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"

[extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[compat]
BenchmarkTools = "1"
ClimaComms = "0.4, 0.5, 0.6"
CUDA = "3, 4, 5"
Colors = "0.12"
DataStructures = "0.18"
DiffEqBase = "6"
Expand All @@ -27,7 +39,10 @@ KernelAbstractions = "0.7, 0.8, 0.9"
Krylov = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2"
OrderedCollections = "1"
PrettyTables = "2"
NVTX = "0.3"
SciMLBase = "1, 2"
StaticArrays = "1"
StatsBase = "0.33, 0.34"
julia = "1.8"
101 changes: 101 additions & 0 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module ClimaTimeSteppersBenchmarkToolsExt


import StatsBase
import SciMLBase
import PrettyTables
import OrderedCollections
import LinearAlgebra as LA
import ClimaTimeSteppers as CTS
using CUDA
import ClimaComms

import BenchmarkTools
# Scale a `BenchmarkTools.Trial` by an integer factor `n`: the recorded times
# and GC times are multiplied element-wise, and the memory and allocation
# counts are multiplied by `n`. The trial's parameters are carried over as-is.
function Base.:*(n::Int, t::BenchmarkTools.Trial)
    scaled_times = t.times .* n
    scaled_gctimes = t.gctimes .* n
    return BenchmarkTools.Trial(t.params, scaled_times, scaled_gctimes, t.memory * n, t.allocs * n)
end

# `trial * n` is defined in terms of `n * trial`.
Base.:*(t::BenchmarkTools.Trial, n::Int) = n * t

include("benchmark_utils.jl")

# Return the per-step call counts for the integrator's algorithm, looked up
# from the algorithm's `name` and the `max_iters` of its Newton's method.
function get_n_calls_per_step(integrator::CTS.DistributedODEIntegrator)
    alg = integrator.alg
    newtons_method = alg.newtons_method
    alg_name = alg.name
    return n_calls_per_step(alg_name, newtons_method.max_iters)
end

# TODO: generalize
# Number of times each benchmarked function is counted per `step!` for ARS343.
# The three implicit-solve entries scale with the Newton iteration cap; the
# remaining entries are fixed counts.
function n_calls_per_step(::CTS.ARS343, max_newton_iters)
    n_implicit = 3 * max_newton_iters
    return Dict(
        "Wfact" => n_implicit,
        "ldiv!" => n_implicit,
        "T_imp!" => n_implicit,
        "T_exp_T_lim!" => 4,
        "lim!" => 4,
        "dss!" => 4,
        "post_explicit!" => 3,
        "post_implicit!" => 4,
        "step!" => 1,
    )
end


"""
    benchmark_step(
        integrator::DistributedODEIntegrator,
        device::ClimaComms.AbstractDevice;
        with_cu_prof = :bprofile, # [:bprofile, :profile]
        trace = false
    )

Benchmark a `DistributedODEIntegrator`.
"""
function CTS.benchmark_step(
    integrator::CTS.DistributedODEIntegrator,
    device::ClimaComms.AbstractDevice;
    with_cu_prof = :bprofile,
    trace = false,
)
    # Overview:
    #  1. benchmark each component of the `ClimaODEFunction` (and `step!`) once;
    #  2. scale each trial by its per-step call count;
    #  3. summarize every scaled trial relative to the `step!` trial and print
    #     the resulting table.
    # Returns `(; table_summary, trials)`; both fields are `nothing` when the
    # problem's function is not a `ClimaODEFunction`.
    (; u, p, t, dt, sol) = integrator
    (; f) = sol.prob
    if !(f isa CTS.ClimaODEFunction)
        @warn "`ClimaTimeSteppers.benchmark_step` is not yet supported for $f."
        return (; table_summary = nothing, trials = nothing)
    end

    W = get_W(integrator)
    X = similar(u)
    # Raw (unscaled) trials, in display order; an entry is `nothing` when
    # `get_trial` returned `nothing` for that component.
    trials₀ = OrderedCollections.OrderedDict()

    #! format: off
    trials₀["Wfact"]          = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace);
    trials₀["ldiv!"]          = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace);
    trials₀["T_imp!"]         = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace);
    trials₀["T_exp_T_lim!"]   = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace);
    trials₀["lim!"]           = get_trial(f.lim!, (X, p, t, u), "lim!", device; with_cu_prof, trace);
    trials₀["dss!"]           = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace);
    trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace);
    trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace);
    trials₀["step!"]          = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace);
    #! format: on

    # Scale each available trial by the number of times it is called per step.
    n_calls_per_step = get_n_calls_per_step(integrator)
    trials = OrderedCollections.OrderedDict()
    for (k, trial) in trials₀
        isnothing(trial) && continue
        trials[k] = trial * n_calls_per_step[k]
    end

    # `trials` only holds non-`nothing` entries, so no further filtering is needed.
    table_summary = OrderedCollections.OrderedDict()
    for (k, trial) in trials
        table_summary[k] = get_summary(trial, trials["step!"])
    end

    tabulate_summary(table_summary; n_calls_per_step)

    return (; table_summary, trials)
end


end
90 changes: 90 additions & 0 deletions ext/benchmark_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#####
##### BenchmarkTools's trial utils
#####

# Summarize `trial` as a named tuple holding pretty-printed strings alongside
# the raw memory and mean-time values. `percentage` is the ratio of the minimum
# time of `trial` to the minimum time of `trial_step`, times 100.
function get_summary(trial, trial_step)
    times = trial.times
    fastest = minimum(times)
    t_mean_val = StatsBase.mean(times)
    # Using some BenchmarkTools internals :/
    return (;
        mem = BenchmarkTools.prettymemory(trial.memory),
        mem_val = trial.memory,
        nalloc = trial.allocs,
        t_min = BenchmarkTools.prettytime(fastest),
        t_max = BenchmarkTools.prettytime(maximum(times)),
        t_mean = BenchmarkTools.prettytime(t_mean_val),
        t_mean_val = t_mean_val,
        t_med = BenchmarkTools.prettytime(StatsBase.median(times)),
        n_samples = length(trial),
        percentage = fastest / minimum(trial_step.times) * 100,
    )
end

# Print `summary` (a keyed collection of `get_summary` results) as a table,
# one row per key, in the collection's iteration order.
#
# Keywords:
#  - `n_calls_per_step`: either `nothing` (the default), in which case the
#    function names are printed as-is, or a collection indexable by the keys
#    of `summary`, in which case each name is suffixed with its call-count
#    multiplier and an `@info` note about the scaling is emitted.
function tabulate_summary(summary; n_calls_per_step = nothing)
    summary_keys = collect(keys(summary))
    # Extract one field of every summary entry, in key order.
    column(field) = map(k -> getproperty(summary[k], field), summary_keys)

    func_names = if isnothing(n_calls_per_step)
        map(k -> string(k), summary_keys)
    else
        @info "(#)x entries have been multiplied by corresponding factors in order to compute percentages"
        map(k -> string(k, " ($(n_calls_per_step[k])x)"), summary_keys)
    end
    table_data = hcat(
        func_names,
        column(:mem),
        column(:nalloc),
        column(:t_min),
        column(:t_max),
        column(:t_mean),
        column(:t_med),
        column(:n_samples),
        column(:percentage),
    )

    # Two-row header: column title on top, sub-title underneath.
    header = (
        ["Function", "Memory", "allocs", "Time", "Time", "Time", "Time", "N-samples", "step! percentage"],
        [" ", "estimate", "estimate", "min", "max", "mean", "median", "", ""],
    )

    # Left-align the function names; right-align every other column.
    return PrettyTables.pretty_table(
        table_data;
        header,
        crop = :none,
        alignment = vcat(:l, repeat([:r], length(header[1]) - 1)),
    )
end

# Nothing to benchmark when the function slot is `nothing`.
get_trial(::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false) = nothing

# Benchmark `f(args...)` and return the resulting `BenchmarkTools` trial.
# `name` is only used in the progress banner. On a `ClimaComms.CUDADevice` the
# benchmarked call is wrapped in `CUDA.@sync`, and afterwards the call is also
# run once under `CUDA.@bprofile` (when `with_cu_prof == :bprofile`) or
# `CUDA.@profile` (any other value), whose result is printed.
function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false)
    on_cuda = device isa ClimaComms.CUDADevice
    f(args...) # compile first
    benchmarkable = if on_cuda
        BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...)
    else
        BenchmarkTools.@benchmarkable $f($(args)...)
    end
    println("--------------- Benchmarking/profiling $name...")
    trial = BenchmarkTools.run(benchmarkable, samples = 10)
    if on_cuda
        profile_result = if with_cu_prof == :bprofile
            CUDA.@bprofile trace = trace f(args...)
        else
            CUDA.@profile trace = trace f(args...)
        end
        println(profile_result)
    end
    println()
    return trial
end

# Accessors used by the benchmark driver to pull functions and argument tuples
# out of an integrator. Methods specialized on `CTS.DistributedODEIntegrator`
# read from `sol.prob.f`; the untyped methods are the fallbacks for any other
# integrator.

# The `W` object handed to `Wfact` / `ldiv!`.
get_W(integrator::CTS.DistributedODEIntegrator) = integrator.cache.newtons_method_cache.j
get_W(integrator) = integrator.cache.W

# Argument tuple for a tendency function: one output buffer followed by the
# state, parameters and time (plus `dt` for `ForwardEulerODEFunction`s).
function f_args(integrator, ::CTS.ForwardEulerODEFunction)
    (; u, p, t, dt) = integrator
    return (copy(u), u, p, t, dt)
end
function f_args(integrator, f)
    (; u, p, t) = integrator
    return (similar(u), u, p, t)
end

# Same as `f_args`, but with two leading output buffers.
function r_args(integrator, ::CTS.ForwardEulerODEFunction)
    (; u, p, t, dt) = integrator
    return (copy(u), copy(u), u, p, t, dt)
end
function r_args(integrator, f)
    (; u, p, t) = integrator
    return (similar(u), similar(u), u, p, t)
end

# Argument tuples for the implicit and the remaining tendencies.
implicit_args(integrator::CTS.DistributedODEIntegrator) = f_args(integrator, integrator.sol.prob.f.T_imp!)
implicit_args(integrator) = f_args(integrator, integrator.f.f1)
remaining_args(integrator::CTS.DistributedODEIntegrator) = r_args(integrator, integrator.sol.prob.f.T_exp_T_lim!)
remaining_args(integrator) = r_args(integrator, integrator.f.f2)

# The functions themselves; `wfact_fun` is the `Wfact` field of the implicit one.
wfact_fun(integrator) = implicit_fun(integrator).Wfact
implicit_fun(integrator::CTS.DistributedODEIntegrator) = integrator.sol.prob.f.T_imp!
implicit_fun(integrator) = integrator.sol.prob.f.f1
remaining_fun(integrator::CTS.DistributedODEIntegrator) = integrator.sol.prob.f.T_exp_T_lim!
remaining_fun(integrator) = integrator.sol.prob.f.f2
18 changes: 14 additions & 4 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.2"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "86292bcc5e784cb4f6a9709b4395280ec762f782"
project_hash = "192050c2153dcc826c1eca5d5b7a774796048dec"

[[deps.ADTypes]]
git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224"
Expand Down Expand Up @@ -248,7 +248,11 @@ weakdeps = ["Krylov"]
deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"]
path = ".."
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.7.20"
version = "0.7.21"
weakdeps = ["BenchmarkTools", "CUDA", "OrderedCollections", "PrettyTables", "StatsBase"]

[deps.ClimaTimeSteppers.extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[[deps.CloseOpenIntervals]]
deps = ["Static", "StaticArrayInterface"]
Expand Down Expand Up @@ -304,7 +308,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"
version = "1.1.1+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
Expand Down Expand Up @@ -1461,6 +1465,12 @@ git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3"

[[deps.StrideArraysCore]]
deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"]
git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682"
Expand Down
4 changes: 4 additions & 0 deletions perf/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
Expand All @@ -15,12 +16,15 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"

Expand Down
54 changes: 36 additions & 18 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
using ArgParse, JET, Test, DiffEqBase, ClimaTimeSteppers
using ClimaComms
using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step
import ClimaTimeSteppers as CTS
import DiffEqBase
using BenchmarkTools, DiffEqBase

include(joinpath(pkgdir(CTS), "test", "problems.jl"))

function main()
algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2))
dt = 0.01
for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe())
integrator = DiffEqBase.init(problem, algorithm; dt)

cache = CTS.init_cache(problem, algorithm)

CTS.step_u!(integrator, cache)

trial = @benchmark CTS.step_u!($integrator, $cache)
show(stdout, MIME("text/plain"), trial)
"""
    parse_commandline()

Parse the script's command-line arguments from `ARGS`.

The only option is `--problem` (a `String`, default `"diffusion2d"`), which
selects the problem to benchmark.

Returns the tuple `(s, parsed_args)`: the `ArgParse` settings object and the
parsed arguments.
"""
function parse_commandline()
    s = ArgParse.ArgParseSettings()
    ArgParse.@add_arg_table s begin
        "--problem"
        # List every value the driver accepts, including the default.
        help = "Problem type [`diffusion2d`, `ode_fun`, `fe`]"
        arg_type = String
        default = "diffusion2d"
    end
    parsed_args = ArgParse.parse_args(ARGS, s)
    return (s, parsed_args)
end
(s, parsed_args) = parse_commandline()
cts = joinpath(dirname(@__DIR__));
include(joinpath(cts, "test", "problems.jl"))
# Build the integrator to benchmark. An `IntegratorTestCase` is unwrapped to
# its `prob` field first.
config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob)

# Initialize an ARS343 IMEX integrator (Newton's method capped at 2 iterations,
# dt = 0.01) for `problem`, then replace its cache with a freshly initialized
# one. Returns the named tuple `(; integrator)`.
function config_integrators(problem)
    tableau = ARS343()
    newtons_method = NewtonsMethod(; max_iters = 2)
    algorithm = CTS.IMEXAlgorithm(tableau, newtons_method)
    integrator = DiffEqBase.init(problem, algorithm; dt = 0.01)
    integrator.cache = CTS.init_cache(problem, algorithm)
    return (; integrator)
end
main()
# Select the problem to benchmark from the `--problem` command-line option.
problem_name = parsed_args["problem"]
prob = if problem_name == "diffusion2d"
    climacore_2Dheat_test_cts(Float64)
elseif problem_name == "ode_fun"
    split_linear_prob_wfact_split()
elseif problem_name == "fe"
    split_linear_prob_wfact_split_fe()
else
    # Report the offending value and the accepted ones instead of a bare
    # "Bad option".
    error("Bad option `--problem $(problem_name)`; expected one of `diffusion2d`, `ode_fun`, `fe`")
end
(; integrator) = config_integrators(prob)

# Benchmark on whichever device ClimaComms selects.
device = ClimaComms.device()
CTS.benchmark_step(integrator, device)
4 changes: 4 additions & 0 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,8 @@ include("solvers/rosenbrock.jl")

include("Callbacks.jl")


"""
    benchmark_step(integrator, device; kwargs...)

Fallback method used when the `ClimaTimeSteppersBenchmarkToolsExt` extension
does not provide a more specific method for the given arguments.

Emits a warning listing the packages that must be loaded to trigger the
extension and returns `nothing`. Keyword arguments are accepted and ignored so
that a call written for the extension method (e.g. with `with_cu_prof` or
`trace`) reaches this warning instead of raising a `MethodError`.
"""
benchmark_step(integrator, device; kwargs...) =
    @warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"

end
2 changes: 1 addition & 1 deletion test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using Aqua
end

@testset "Aqua tests (remaining)" begin
Aqua.test_all(ClimaTimeSteppers; ambiguities=false)
Aqua.test_all(ClimaTimeSteppers; ambiguities = false)
end

nothing

0 comments on commit 83cf6b6

Please sign in to comment.