From 1090936efdda5d95411e5c6323c258f6b70a8d86 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 16 Jul 2024 01:32:50 -0400 Subject: [PATCH 1/5] Make work-precision benchmarks more robust to failures --- Project.toml | 2 +- src/benchmark.jl | 114 +++++++++++++++++++++++++---------------------- 2 files changed, 61 insertions(+), 55 deletions(-) diff --git a/Project.toml b/Project.toml index 65c6a0d..47a0f58 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqDevTools" uuid = "f3b72e0c-5b89-59e1-b016-84e28bfd966d" authors = ["Chris Rackauckas "] -version = "2.44.3" +version = "2.44.4" [deps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/benchmark.jl b/src/benchmark.jl index 0f39699..27a9841 100644 --- a/src/benchmark.jl +++ b/src/benchmark.jl @@ -217,69 +217,75 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing; stats[i] = sol.stats - if haskey(kwargs, :prob_choice) - cur_appxsol = appxsol[kwargs[:prob_choice]] - elseif prob isa AbstractArray - cur_appxsol = appxsol[1] - else - cur_appxsol = appxsol - end - - if cur_appxsol !== nothing - errsol = appxtrue(sol, cur_appxsol) - errors[i] = Dict{Symbol, Float64}() - for err in keys(errsol.errors) - errors[i][err] = mean(errsol.errors[err]) - end - else - errors[i] = Dict{Symbol, Float64}() - for err in keys(sol.errors) - errors[i][err] = mean(sol.errors[err]) + if SciMLBase.successful_retcode(sol) + if haskey(kwargs, :prob_choice) + cur_appxsol = appxsol[kwargs[:prob_choice]] + elseif prob isa AbstractArray + cur_appxsol = appxsol[1] + else + cur_appxsol = appxsol end - end - - benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol, - abstols = abstols, reltols = reltols, kwargs = kwargs - if dts === nothing - if _prob isa DAEProblem - () -> @elapsed solve(_prob, alg, sol.u, sol.t; - abstol = abstols[i], - reltol = reltols[i], - timeseries_errors = false, - dense_errors = false, kwargs...) - else - () -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k; - abstol = abstols[i], - reltol = reltols[i], - timeseries_errors = false, - dense_errors = false, kwargs...) + if cur_appxsol !== nothing + errsol = appxtrue(sol, cur_appxsol) + errors[i] = Dict{Symbol, Float64}() + for err in keys(errsol.errors) + errors[i][err] = mean(errsol.errors[err]) end else - if _prob isa DAEProblem - () -> @elapsed solve(_prob, alg, sol.u, sol.t; - abstol = abstols[i], - reltol = reltols[i], - dt = dts[i], - timeseries_errors = false, - dense_errors = false, kwargs...) + errors[i] = Dict{Symbol, Float64}() + for err in keys(sol.errors) + errors[i][err] = mean(sol.errors[err]) + end + end + + benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol, + abstols = abstols, reltols = reltols, kwargs = kwargs + + if dts === nothing + if _prob isa DAEProblem + () -> @elapsed solve(_prob, alg, sol.u, sol.t; + abstol = abstols[i], + reltol = reltols[i], + timeseries_errors = false, + dense_errors = false, kwargs...) + else + () -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k; + abstol = abstols[i], + reltol = reltols[i], + timeseries_errors = false, + dense_errors = false, kwargs...) + end else - () -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k; - abstol = abstols[i], - reltol = reltols[i], - dt = dts[i], - timeseries_errors = false, - dense_errors = false, kwargs...) + if _prob isa DAEProblem + () -> @elapsed solve(_prob, alg, sol.u, sol.t; + abstol = abstols[i], + reltol = reltols[i], + dt = dts[i], + timeseries_errors = false, + dense_errors = false, kwargs...) + else + () -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k; + abstol = abstols[i], + reltol = reltols[i], + dt = dts[i], + timeseries_errors = false, + dense_errors = false, kwargs...) + end end end - end - benchmark_f() # pre-compile + benchmark_f() # pre-compile - b_t = benchmark_f() - if b_t > seconds - times[i] = b_t + b_t = benchmark_f() + if b_t > seconds + times[i] = b_t + else + times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t) + end else - times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t) + # Unsuccessful retcode, give NaN time + errors[i] = Dict(:l∞ => NaN, :L2 => NaN, :final => NaN, :l2 => NaN, :L∞ => NaN) + times[i] = NaN end end end From c0c094fae2d485abd45fe6a79391e677686542db Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 16 Jul 2024 03:00:50 -0400 Subject: [PATCH 2/5] bump nonlinearsolve --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 47a0f58..09a8a9f 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Distributed = "1.9" LinearAlgebra = "1.9" Logging = "1.9" NLsolve = "4.2" -NonlinearSolve = "1, 2" +NonlinearSolve = "3" ODEProblemLibrary = "0.1" OrdinaryDiffEq = "6" ParameterizedFunctions = "5" From 567424119694e0636461e8582e3ac9d3068b41e0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 16 Jul 2024 03:51:51 -0400 Subject: [PATCH 3/5] latest nonlinearsolve --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 09a8a9f..5fbed55 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ Distributed = "1.9" LinearAlgebra = "1.9" Logging = "1.9" NLsolve = "4.2" -NonlinearSolve = "3" +NonlinearSolve = "3.13" ODEProblemLibrary = "0.1" OrdinaryDiffEq = "6" ParameterizedFunctions = "5" From d09311d28f77a3b29a682db63f049d7b68e85ed2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 16 Jul 2024 04:03:21 -0400 Subject: [PATCH 4/5] fix deprecations --- src/test_solution.jl | 4 ++-- test/benchmark_tests.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test_solution.jl b/src/test_solution.jl index 279b54d..a8e4114 100644 --- a/src/test_solution.jl +++ b/src/test_solution.jl @@ -117,9 +117,9 @@ function appxtrue(sim::EnsembleSolution, appx_setup; kwargs...) for i in eachindex(sim) prob = sim[i].prob prob2 = SDEProblem(prob.f, prob.g, prob.u0, prob.tspan, - noise = NoiseWrapper(sim[i].W)) + noise = NoiseWrapper(sim.u[i].W)) true_sol = solve(prob2, appx_setup[:alg]; appx_setup...) - _new_sols[i] = appxtrue(sim[i], true_sol) + _new_sols[i] = appxtrue(sim.u[i], true_sol) end new_sols = convert(Vector{typeof(_new_sols[1])}, _new_sols) calculate_ensemble_errors(new_sols; converged = sim.converged, diff --git a/test/benchmark_tests.jl b/test/benchmark_tests.jl index f2e5967..66e96cd 100644 --- a/test/benchmark_tests.jl +++ b/test/benchmark_tests.jl @@ -23,7 +23,7 @@ setups = [Dict(:alg => RK4()); Dict(:alg => Euler()); Dict(:alg => BS3()); t1 = @elapsed sol = solve(prob, RK4(), dt = 1 / 2^(4)) t2 = @elapsed sol2 = solve(prob, setups[1][:alg], dt = 1 / 2^(4)) -@test (sol2[end] == sol[end]) +@test (sol2.u[end] == sol.u[end]) test_sol_2Dlinear = TestSolution( solve(prob_ode_2Dlinear, Vern7(), abstol = 1 / 10^14, reltol = 1 / 10^14)) From 026bd4b8f9aff423a6672a9c007a775db1be6030 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Tue, 16 Jul 2024 04:35:07 -0400 Subject: [PATCH 5/5] move legend --- src/plotrecipes.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/plotrecipes.jl b/src/plotrecipes.jl index e171685..e001880 100644 --- a/src/plotrecipes.jl +++ b/src/plotrecipes.jl @@ -110,6 +110,7 @@ end ys = [get_val_from_wp(wp, y) for wp in wp_set.wps] xguide --> key_to_label(x) yguide --> key_to_label(y) + legend --> :outerright label --> reshape(wp_set.names, 1, length(wp_set)) return xs, ys elseif view == :dt_convergence