Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add benchmark utility extension #277

Merged
merged 4 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,20 @@ NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"

[extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[compat]
BenchmarkTools = "1"
ClimaComms = "0.4, 0.5, 0.6"
CUDA = "3, 4, 5"
Colors = "0.12"
DataStructures = "0.18"
DiffEqBase = "6"
Expand All @@ -27,7 +39,10 @@ KernelAbstractions = "0.7, 0.8, 0.9"
Krylov = "0.8, 0.9"
LinearAlgebra = "1"
LinearOperators = "2"
OrderedCollections = "1"
PrettyTables = "2"
NVTX = "0.3"
SciMLBase = "1, 2"
StaticArrays = "1"
StatsBase = "0.33, 0.34"
julia = "1.8"
101 changes: 101 additions & 0 deletions ext/ClimaTimeSteppersBenchmarkToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module ClimaTimeSteppersBenchmarkToolsExt


import StatsBase
import SciMLBase
import PrettyTables
import OrderedCollections
import LinearAlgebra as LA
import ClimaTimeSteppers as CTS
using CUDA
import ClimaComms

import BenchmarkTools
Base.:*(t::BenchmarkTools.Trial, n::Int) = n * t
Base.:*(n::Int, t::BenchmarkTools.Trial) =
BenchmarkTools.Trial(t.params, t.times .* n, t.gctimes .* n, t.memory * n, t.allocs * n)

include("benchmark_utils.jl")

function get_n_calls_per_step(integrator::CTS.DistributedODEIntegrator)
(; alg) = integrator
(; newtons_method, name) = alg
(; max_iters) = newtons_method
n_calls_per_step(name, max_iters)
end

# TODO: generalize
n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
"Wfact" => 3 * max_newton_iters,
"ldiv!" => 3 * max_newton_iters,
"T_imp!" => 3 * max_newton_iters,
"T_exp_T_lim!" => 4,
"lim!" => 4,
"dss!" => 4,
"post_explicit!" => 3,
"post_implicit!" => 4,
"step!" => 1,
)


"""
benchmark_step(
integrator::DistributedODEIntegrator,
device::ClimaComms.AbstractDevice;
with_cu_prof = :bfrofile, # [:bprofile, :profile]
trace = false
)

Benchmark a DistributedODEIntegrator
"""
function CTS.benchmark_step(
integrator::CTS.DistributedODEIntegrator,
device::ClimaComms.AbstractDevice;
with_cu_prof = :bprofile,
trace = false,
)
(; u, p, t, dt, sol, alg) = integrator
(; f) = sol.prob
if f isa CTS.ClimaODEFunction

W = get_W(integrator)
X = similar(u)
trials₀ = OrderedCollections.OrderedDict()

#! format: off
trials₀["Wfact"] = get_trial(wfact_fun(integrator), (W, u, p, dt, t), "Wfact", device; with_cu_prof, trace);
trials₀["ldiv!"] = get_trial(LA.ldiv!, (X, W, u), "ldiv!", device; with_cu_prof, trace);
trials₀["T_imp!"] = get_trial(implicit_fun(integrator), implicit_args(integrator), "T_imp!", device; with_cu_prof, trace);
trials₀["T_exp_T_lim!"] = get_trial(remaining_fun(integrator), remaining_args(integrator), "T_exp_T_lim!", device; with_cu_prof, trace);
trials₀["lim!"] = get_trial(f.lim!, (X, p, t, u), "lim!", device; with_cu_prof, trace);
trials₀["dss!"] = get_trial(f.dss!, (u, p, t), "dss!", device; with_cu_prof, trace);
trials₀["post_explicit!"] = get_trial(f.post_explicit!, (u, p, t), "post_explicit!", device; with_cu_prof, trace);
trials₀["post_implicit!"] = get_trial(f.post_implicit!, (u, p, t), "post_implicit!", device; with_cu_prof, trace);
trials₀["step!"] = get_trial(SciMLBase.step!, (integrator, ), "step!", device; with_cu_prof, trace);
#! format: on

trials = OrderedCollections.OrderedDict()

n_calls_per_step = get_n_calls_per_step(integrator)
for k in keys(trials₀)
isnothing(trials₀[k]) && continue
trials[k] = trials₀[k] * n_calls_per_step[k]
end

table_summary = OrderedCollections.OrderedDict()
for k in keys(trials)
isnothing(trials[k]) && continue
table_summary[k] = get_summary(trials[k], trials["step!"])
end

tabulate_summary(table_summary; n_calls_per_step)

return (; table_summary, trials)
else
@warn "`ClimaTimeSteppers.benchmark` is not yet supported for $f."
return (; table_summary = nothing, trials = nothing)
end
end


end
90 changes: 90 additions & 0 deletions ext/benchmark_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#####
##### BenchmarkTools's trial utils
#####

get_summary(trial, trial_step) = (;
# Using some BenchmarkTools internals :/
mem = BenchmarkTools.prettymemory(trial.memory),
mem_val = trial.memory,
nalloc = trial.allocs,
t_min = BenchmarkTools.prettytime(minimum(trial.times)),
t_max = BenchmarkTools.prettytime(maximum(trial.times)),
t_mean = BenchmarkTools.prettytime(StatsBase.mean(trial.times)),
t_mean_val = StatsBase.mean(trial.times),
t_med = BenchmarkTools.prettytime(StatsBase.median(trial.times)),
n_samples = length(trial),
percentage = minimum(trial.times) / minimum(trial_step.times) * 100,
)

function tabulate_summary(summary; n_calls_per_step)
summary_keys = collect(keys(summary))
mem = map(k -> summary[k].mem, summary_keys)
nalloc = map(k -> summary[k].nalloc, summary_keys)
t_mean = map(k -> summary[k].t_mean, summary_keys)
t_min = map(k -> summary[k].t_min, summary_keys)
t_max = map(k -> summary[k].t_max, summary_keys)
t_med = map(k -> summary[k].t_med, summary_keys)
n_samples = map(k -> summary[k].n_samples, summary_keys)
percentage = map(k -> summary[k].percentage, summary_keys)

func_names = if isnothing(n_calls_per_step)
map(k -> string(k), collect(keys(summary)))
else
@info "(#)x entries have been multiplied by corresponding factors in order to compute percentages"
map(k -> string(k, " ($(n_calls_per_step[k])x)"), collect(keys(summary)))
end
table_data = hcat(func_names, mem, nalloc, t_min, t_max, t_mean, t_med, n_samples, percentage)

header = (
["Function", "Memory", "allocs", "Time", "Time", "Time", "Time", "N-samples", "step! percentage"],
[" ", "estimate", "estimate", "min", "max", "mean", "median", "", ""],
)

PrettyTables.pretty_table(
table_data;
header,
crop = :none,
alignment = vcat(:l, repeat([:r], length(header[1]) - 1)),
)
end

get_trial(f::Nothing, args, name, device; with_cu_prof = :bprofile, trace = false) = nothing
function get_trial(f, args, name, device; with_cu_prof = :bprofile, trace = false)
f(args...) # compile first
b = if device isa ClimaComms.CUDADevice
BenchmarkTools.@benchmarkable CUDA.@sync $f($(args)...)
else
BenchmarkTools.@benchmarkable $f($(args)...)
end
sample_limit = 10
println("--------------- Benchmarking/profiling $name...")
trial = BenchmarkTools.run(b, samples = sample_limit)
if device isa ClimaComms.CUDADevice
p = if with_cu_prof == :bprofile
CUDA.@bprofile trace = trace f(args...)
else
CUDA.@profile trace = trace f(args...)
end
println(p)
end
println()
return trial
end

get_W(i::CTS.DistributedODEIntegrator) = i.cache.newtons_method_cache.j
get_W(i) = i.cache.W
f_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), i.u, i.p, i.t, i.dt)
f_args(i, f) = (similar(i.u), i.u, i.p, i.t)

r_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), copy(i.u), i.u, i.p, i.t, i.dt)
r_args(i, f) = (similar(i.u), similar(i.u), i.u, i.p, i.t)

implicit_args(i::CTS.DistributedODEIntegrator) = f_args(i, i.sol.prob.f.T_imp!)
implicit_args(i) = f_args(i, i.f.f1)
remaining_args(i::CTS.DistributedODEIntegrator) = r_args(i, i.sol.prob.f.T_exp_T_lim!)
remaining_args(i) = r_args(i, i.f.f2)
wfact_fun(i) = implicit_fun(i).Wfact
implicit_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_imp!
implicit_fun(i) = i.sol.prob.f.f1
remaining_fun(i::CTS.DistributedODEIntegrator) = i.sol.prob.f.T_exp_T_lim!
remaining_fun(i) = i.sol.prob.f.f2
18 changes: 14 additions & 4 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.2"
julia_version = "1.10.3"
manifest_format = "2.0"
project_hash = "86292bcc5e784cb4f6a9709b4395280ec762f782"
project_hash = "192050c2153dcc826c1eca5d5b7a774796048dec"

[[deps.ADTypes]]
git-tree-sha1 = "fcdb00b4d412b80ab08e39978e3bdef579e5e224"
Expand Down Expand Up @@ -248,7 +248,11 @@ weakdeps = ["Krylov"]
deps = ["ClimaComms", "Colors", "DataStructures", "DiffEqBase", "DiffEqCallbacks", "KernelAbstractions", "Krylov", "LinearAlgebra", "LinearOperators", "NVTX", "SciMLBase", "StaticArrays"]
path = ".."
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
version = "0.7.20"
version = "0.7.21"
weakdeps = ["BenchmarkTools", "CUDA", "OrderedCollections", "PrettyTables", "StatsBase"]

[deps.ClimaTimeSteppers.extensions]
ClimaTimeSteppersBenchmarkToolsExt = ["CUDA", "BenchmarkTools", "OrderedCollections", "StatsBase", "PrettyTables"]

[[deps.CloseOpenIntervals]]
deps = ["Static", "StaticArrayInterface"]
Expand Down Expand Up @@ -304,7 +308,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.1.0+0"
version = "1.1.1+0"

[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
Expand Down Expand Up @@ -1461,6 +1465,12 @@ git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed"
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
version = "1.7.0"

[[deps.StatsBase]]
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.34.3"

[[deps.StrideArraysCore]]
deps = ["ArrayInterface", "CloseOpenIntervals", "IfElse", "LayoutPointers", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface", "ThreadingUtilities"]
git-tree-sha1 = "25349bf8f63aa36acbff5e3550a86e9f5b0ef682"
Expand Down
4 changes: 4 additions & 0 deletions perf/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
Expand All @@ -15,12 +16,15 @@ Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125"
NVTX = "5da4648a-3479-48b8-97b9-01cb529c0a1f"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
ProfileCanvas = "efd6af41-a80b-495e-886c-e51b0c7d77a3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
URIs = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4"

Expand Down
54 changes: 36 additions & 18 deletions perf/benchmark.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,39 @@
using ArgParse, JET, Test, DiffEqBase, ClimaTimeSteppers
using ClimaComms
using CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables # needed for CTS.benchmark_step
import ClimaTimeSteppers as CTS
import DiffEqBase
using BenchmarkTools, DiffEqBase

include(joinpath(pkgdir(CTS), "test", "problems.jl"))

function main()
algorithm = CTS.IMEXAlgorithm(CTS.ARS343(), CTS.NewtonsMethod(; max_iters = 2))
dt = 0.01
for problem in (split_linear_prob_wfact_split(), split_linear_prob_wfact_split_fe())
integrator = DiffEqBase.init(problem, algorithm; dt)

cache = CTS.init_cache(problem, algorithm)

CTS.step_u!(integrator, cache)

trial = @benchmark CTS.step_u!($integrator, $cache)
show(stdout, MIME("text/plain"), trial)
function parse_commandline()
s = ArgParse.ArgParseSettings()
ArgParse.@add_arg_table s begin
"--problem"
help = "Problem type [`ode_fun`, `fe`]"
arg_type = String
default = "diffusion2d"
end
parsed_args = ArgParse.parse_args(ARGS, s)
return (s, parsed_args)
end
(s, parsed_args) = parse_commandline()
cts = joinpath(dirname(@__DIR__));
include(joinpath(cts, "test", "problems.jl"))
config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob)
function config_integrators(problem)
algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2))
dt = 0.01
integrator = DiffEqBase.init(problem, algorithm; dt)
integrator.cache = CTS.init_cache(problem, algorithm)
return (; integrator)
end
main()
prob = if parsed_args["problem"] == "diffusion2d"
climacore_2Dheat_test_cts(Float64)
elseif parsed_args["problem"] == "ode_fun"
split_linear_prob_wfact_split()
elseif parsed_args["problem"] == "fe"
split_linear_prob_wfact_split_fe()
else
error("Bad option")
end
(; integrator) = config_integrators(prob)

device = ClimaComms.device()
CTS.benchmark_step(integrator, device)
4 changes: 4 additions & 0 deletions src/ClimaTimeSteppers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,8 @@ include("solvers/rosenbrock.jl")

include("Callbacks.jl")


benchmark_step(integrator, device) =
@warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"

end
2 changes: 1 addition & 1 deletion test/aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using Aqua
end

@testset "Aqua tests (remaining)" begin
Aqua.test_all(ClimaTimeSteppers; ambiguities=false)
Aqua.test_all(ClimaTimeSteppers; ambiguities = false)
end

nothing
Loading