diff --git a/perf/jet.jl b/perf/jet.jl index 17b6ae19..ad230cb4 100644 --- a/perf/jet.jl +++ b/perf/jet.jl @@ -1,4 +1,5 @@ -using ArgParse, JET, Test, BenchmarkTools, DiffEqBase, ClimaTimeSteppers +# using Revise; include("perf/jet.jl") +using ArgParse, JET, Test, BenchmarkTools, SciMLBase, ClimaTimeSteppers import ClimaTimeSteppers as CTS function parse_commandline() s = ArgParse.ArgParseSettings() @@ -15,10 +16,29 @@ end cts = joinpath(dirname(@__DIR__)); include(joinpath(cts, "test", "problems.jl")) config_integrators(itc::IntegratorTestCase) = config_integrators(itc.prob) + +struct Foo end +foo!(integrator) = nothing +(::Foo)(integrator) = foo!(integrator) +struct Bar end +bar!(integrator) = nothing +(::Bar)(integrator) = bar!(integrator) + +function discrete_cb(cb!, n) + cond = if n == 1 + (u, t, integrator) -> isnothing(cb!(integrator)) + else + (u, t, integrator) -> isnothing(cb!(integrator)) || rand() ≤ 0.5 + end + SciMLBase.DiscreteCallback(cond, cb!;) +end function config_integrators(problem) algorithm = CTS.IMEXAlgorithm(ARS343(), NewtonsMethod(; max_iters = 2)) dt = 0.01 - integrator = DiffEqBase.init(problem, algorithm; dt) + discrete_callbacks = (discrete_cb(Foo(), 0), discrete_cb(Bar(), 0), discrete_cb(Foo(), 1), discrete_cb(Bar(), 1)) + callback = SciMLBase.CallbackSet((), discrete_callbacks) + + integrator = SciMLBase.init(problem, algorithm; dt, callback) integrator.cache = CTS.init_cache(problem, algorithm) return (; integrator) end @@ -33,7 +53,12 @@ else end (; integrator) = config_integrators(prob) -CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs -step_allocs = @allocated CTS.step_u!(integrator, integrator.cache) -@show step_allocs -JET.@test_opt CTS.step_u!(integrator, integrator.cache) +@testset "JET / allocations" begin + CTS.step_u!(integrator, integrator.cache) # compile first, and make sure it runs + step_allocs = @allocated CTS.step_u!(integrator, integrator.cache) + @show step_allocs + JET.@test_opt CTS.step_u!(integrator, integrator.cache) + + CTS.__step!(integrator) # compile first, and make sure it runs + JET.@test_opt broken = true CTS.__step!(integrator) +end