Skip to content

Commit

Permalink
Merge pull request #142 from SciML/wps
Browse files Browse the repository at this point in the history
Make work-precision benchmarks more robust to failures
  • Loading branch information
ChrisRackauckas authored Jul 16, 2024
2 parents c085b60 + 026bd4b commit d01dd9b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 59 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqDevTools"
uuid = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "2.44.3"
version = "2.44.4"

[deps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -28,7 +28,7 @@ Distributed = "1.9"
LinearAlgebra = "1.9"
Logging = "1.9"
NLsolve = "4.2"
NonlinearSolve = "1, 2"
NonlinearSolve = "3.13"
ODEProblemLibrary = "0.1"
OrdinaryDiffEq = "6"
ParameterizedFunctions = "5"
Expand Down
114 changes: 60 additions & 54 deletions src/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,69 +217,75 @@ function WorkPrecision(prob, alg, abstols, reltols, dts = nothing;

stats[i] = sol.stats

if haskey(kwargs, :prob_choice)
cur_appxsol = appxsol[kwargs[:prob_choice]]
elseif prob isa AbstractArray
cur_appxsol = appxsol[1]
else
cur_appxsol = appxsol
end

if cur_appxsol !== nothing
errsol = appxtrue(sol, cur_appxsol)
errors[i] = Dict{Symbol, Float64}()
for err in keys(errsol.errors)
errors[i][err] = mean(errsol.errors[err])
end
else
errors[i] = Dict{Symbol, Float64}()
for err in keys(sol.errors)
errors[i][err] = mean(sol.errors[err])
if SciMLBase.successful_retcode(sol)
if haskey(kwargs, :prob_choice)
cur_appxsol = appxsol[kwargs[:prob_choice]]
elseif prob isa AbstractArray
cur_appxsol = appxsol[1]
else
cur_appxsol = appxsol
end
end

benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
abstols = abstols, reltols = reltols, kwargs = kwargs

if dts === nothing
if _prob isa DAEProblem
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
abstol = abstols[i],
reltol = reltols[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
else
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
abstol = abstols[i],
reltol = reltols[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
if cur_appxsol !== nothing
errsol = appxtrue(sol, cur_appxsol)
errors[i] = Dict{Symbol, Float64}()
for err in keys(errsol.errors)
errors[i][err] = mean(errsol.errors[err])
end
else
if _prob isa DAEProblem
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
abstol = abstols[i],
reltol = reltols[i],
dt = dts[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
errors[i] = Dict{Symbol, Float64}()
for err in keys(sol.errors)
errors[i][err] = mean(sol.errors[err])
end
end

benchmark_f = let dts = dts, _prob = _prob, alg = alg, sol = sol,
abstols = abstols, reltols = reltols, kwargs = kwargs

if dts === nothing
if _prob isa DAEProblem
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
abstol = abstols[i],
reltol = reltols[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
else
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
abstol = abstols[i],
reltol = reltols[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
end
else
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
abstol = abstols[i],
reltol = reltols[i],
dt = dts[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
if _prob isa DAEProblem
() -> @elapsed solve(_prob, alg, sol.u, sol.t;
abstol = abstols[i],
reltol = reltols[i],
dt = dts[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
else
() -> @elapsed solve(_prob, alg, sol.u, sol.t, sol.k;
abstol = abstols[i],
reltol = reltols[i],
dt = dts[i],
timeseries_errors = false,
dense_errors = false, kwargs...)
end
end
end
end
benchmark_f() # pre-compile
benchmark_f() # pre-compile

b_t = benchmark_f()
if b_t > seconds
times[i] = b_t
b_t = benchmark_f()
if b_t > seconds
times[i] = b_t
else
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
end
else
times[i] = mapreduce(i -> benchmark_f(), min, 2:numruns; init = b_t)
# Unsuccessful retcode, give NaN time
errors[i] = Dict(:l∞ => NaN, :L2 => NaN, :final => NaN, :l2 => NaN, :L∞ => NaN)
times[i] = NaN
end
end
end
Expand Down
1 change: 1 addition & 0 deletions src/plotrecipes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ end
ys = [get_val_from_wp(wp, y) for wp in wp_set.wps]
xguide --> key_to_label(x)
yguide --> key_to_label(y)
legend --> :outerright
label --> reshape(wp_set.names, 1, length(wp_set))
return xs, ys
elseif view == :dt_convergence
Expand Down
4 changes: 2 additions & 2 deletions src/test_solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ function appxtrue(sim::EnsembleSolution, appx_setup; kwargs...)
for i in eachindex(sim)
prob = sim[i].prob
prob2 = SDEProblem(prob.f, prob.g, prob.u0, prob.tspan,
noise = NoiseWrapper(sim[i].W))
noise = NoiseWrapper(sim.u[i].W))
true_sol = solve(prob2, appx_setup[:alg]; appx_setup...)
_new_sols[i] = appxtrue(sim[i], true_sol)
_new_sols[i] = appxtrue(sim.u[i], true_sol)
end
new_sols = convert(Vector{typeof(_new_sols[1])}, _new_sols)
calculate_ensemble_errors(new_sols; converged = sim.converged,
Expand Down
2 changes: 1 addition & 1 deletion test/benchmark_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ setups = [Dict(:alg => RK4()); Dict(:alg => Euler()); Dict(:alg => BS3());
t1 = @elapsed sol = solve(prob, RK4(), dt = 1 / 2^(4))
t2 = @elapsed sol2 = solve(prob, setups[1][:alg], dt = 1 / 2^(4))

@test (sol2[end] == sol[end])
@test (sol2.u[end] == sol.u[end])

test_sol_2Dlinear = TestSolution(
solve(prob_ode_2Dlinear, Vern7(), abstol = 1 / 10^14, reltol = 1 / 10^14))
Expand Down

0 comments on commit d01dd9b

Please sign in to comment.