diff --git a/AUTHORS.md b/AUTHORS.md index 54d6321633..5ab164c0ed 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -38,6 +38,7 @@ are listed in alphabetical order: * Julia Odenthal * Sigrun Ortleb * Hendrik Ranocha +* Warisa Roongaraya * Andrés M. Rueda-Ramírez * Felipe Santillan * Michael Schlottke-Lakemper diff --git a/src/time_integration/methods_2N.jl b/src/time_integration/methods_2N.jl index f3b09b01e9..e5b970c6bd 100644 --- a/src/time_integration/methods_2N.jl +++ b/src/time_integration/methods_2N.jl @@ -104,9 +104,8 @@ function Base.getproperty(integrator::SimpleIntegrator2N, field::Symbol) return getfield(integrator, field) end -# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1 -function solve(ode::ODEProblem, alg::T; - dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N} +function init(ode::ODEProblem, alg::SimpleAlgorithm2N; + dt, callback = nothing, kwargs...) u = copy(ode.u0) du = similar(u) u_tmp = similar(u) @@ -129,67 +128,84 @@ function solve(ode::ODEProblem, alg::T; error("unsupported") end + return integrator +end + +# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1 +function solve(ode::ODEProblem, alg::SimpleAlgorithm2N; + dt, callback = nothing, kwargs...) + integrator = init(ode, alg, dt = dt, callback = callback; kwargs...) + + # Start actual solve solve!(integrator) end function solve!(integrator::SimpleIntegrator2N) @unpack prob = integrator.sol + + integrator.finalstep = false + + @trixi_timeit timer() "main loop" while !integrator.finalstep + step!(integrator) + end # "main loop" timer + + return TimeIntegratorSolution((first(prob.tspan), integrator.t), + (prob.u0, integrator.u), + integrator.sol.prob) +end + +function step!(integrator::SimpleIntegrator2N) + @unpack prob = integrator.sol @unpack alg = integrator t_end = last(prob.tspan) callbacks = integrator.opts.callback - integrator.finalstep = false - @trixi_timeit timer() "main loop" while !integrator.finalstep - if isnan(integrator.dt) - error("time step size `dt` is NaN") - end + @assert !integrator.finalstep + if isnan(integrator.dt) + error("time step size `dt` is NaN") + end - # if the next iteration would push the simulation beyond the end time, set dt accordingly - if integrator.t + integrator.dt > t_end || - isapprox(integrator.t + integrator.dt, t_end) - integrator.dt = t_end - integrator.t - terminate!(integrator) - end + # if the next iteration would push the simulation beyond the end time, set dt accordingly + if integrator.t + integrator.dt > t_end || + isapprox(integrator.t + integrator.dt, t_end) + integrator.dt = t_end - integrator.t + terminate!(integrator) + end - # one time step - integrator.u_tmp .= 0 - for stage in eachindex(alg.c) - t_stage = integrator.t + integrator.dt * alg.c[stage] - integrator.f(integrator.du, integrator.u, prob.p, t_stage) - - a_stage = alg.a[stage] - b_stage_dt = alg.b[stage] * integrator.dt - @trixi_timeit timer() "Runge-Kutta step" begin - @threaded for i in eachindex(integrator.u) - integrator.u_tmp[i] = integrator.du[i] - - integrator.u_tmp[i] * a_stage - integrator.u[i] += integrator.u_tmp[i] * b_stage_dt - end + # one time step + integrator.u_tmp .= 0 + for stage in eachindex(alg.c) + t_stage = integrator.t + integrator.dt * alg.c[stage] + integrator.f(integrator.du, integrator.u, prob.p, t_stage) + + a_stage = alg.a[stage] + b_stage_dt = alg.b[stage] * integrator.dt + @trixi_timeit timer() "Runge-Kutta step" begin + @threaded for i in eachindex(integrator.u) + integrator.u_tmp[i] = integrator.du[i] - + integrator.u_tmp[i] * a_stage + integrator.u[i] += integrator.u_tmp[i] * b_stage_dt end end - integrator.iter += 1 - integrator.t += integrator.dt - - # handle callbacks - if callbacks isa CallbackSet - foreach(callbacks.discrete_callbacks) do cb - if cb.condition(integrator.u, integrator.t, integrator) - cb.affect!(integrator) - end - return nothing + end + integrator.iter += 1 + integrator.t += integrator.dt + + # handle callbacks + if callbacks isa CallbackSet + foreach(callbacks.discrete_callbacks) do cb + if cb.condition(integrator.u, integrator.t, integrator) + cb.affect!(integrator) end - end - - # respect maximum number of iterations - if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep - @warn "Interrupted. Larger maxiters is needed." - terminate!(integrator) + return nothing end end - return TimeIntegratorSolution((first(prob.tspan), integrator.t), - (prob.u0, integrator.u), - integrator.sol.prob) + # respect maximum number of iterations + if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep + @warn "Interrupted. Larger maxiters is needed." + terminate!(integrator) + end end # get a cache where the RHS can be stored diff --git a/src/time_integration/methods_3Sstar.jl b/src/time_integration/methods_3Sstar.jl index 7b70466606..6128d1551d 100644 --- a/src/time_integration/methods_3Sstar.jl +++ b/src/time_integration/methods_3Sstar.jl @@ -171,9 +171,8 @@ function Base.getproperty(integrator::SimpleIntegrator3Sstar, field::Symbol) return getfield(integrator, field) end -# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1 -function solve(ode::ODEProblem, alg::T; - dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm3Sstar} +function init(ode::ODEProblem, alg::SimpleAlgorithm3Sstar; + dt, callback = nothing, kwargs...) u = copy(ode.u0) du = similar(u) u_tmp1 = similar(u) @@ -199,73 +198,90 @@ function solve(ode::ODEProblem, alg::T; error("unsupported") end + return integrator +end + +# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1 +function solve(ode::ODEProblem, alg::SimpleAlgorithm3Sstar; + dt, callback = nothing, kwargs...) + integrator = init(ode, alg, dt = dt, callback = callback; kwargs...) + + # Start actual solve solve!(integrator) end function solve!(integrator::SimpleIntegrator3Sstar) @unpack prob = integrator.sol + + integrator.finalstep = false + + @trixi_timeit timer() "main loop" while !integrator.finalstep + step!(integrator) + end # "main loop" timer + + return TimeIntegratorSolution((first(prob.tspan), integrator.t), + (prob.u0, integrator.u), + integrator.sol.prob) +end + +function step!(integrator::SimpleIntegrator3Sstar) + @unpack prob = integrator.sol @unpack alg = integrator t_end = last(prob.tspan) callbacks = integrator.opts.callback - integrator.finalstep = false - @trixi_timeit timer() "main loop" while !integrator.finalstep - if isnan(integrator.dt) - error("time step size `dt` is NaN") - end + @assert !integrator.finalstep + if isnan(integrator.dt) + error("time step size `dt` is NaN") + end - # if the next iteration would push the simulation beyond the end time, set dt accordingly - if integrator.t + integrator.dt > t_end || - isapprox(integrator.t + integrator.dt, t_end) - integrator.dt = t_end - integrator.t - terminate!(integrator) - end + # if the next iteration would push the simulation beyond the end time, set dt accordingly + if integrator.t + integrator.dt > t_end || + isapprox(integrator.t + integrator.dt, t_end) + integrator.dt = t_end - integrator.t + terminate!(integrator) + end - # one time step - integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1)) - integrator.u_tmp2 .= integrator.u - for stage in eachindex(alg.c) - t_stage = integrator.t + integrator.dt * alg.c[stage] - prob.f(integrator.du, integrator.u, prob.p, t_stage) - - delta_stage = alg.delta[stage] - gamma1_stage = alg.gamma1[stage] - gamma2_stage = alg.gamma2[stage] - gamma3_stage = alg.gamma3[stage] - beta_stage_dt = alg.beta[stage] * integrator.dt - @trixi_timeit timer() "Runge-Kutta step" begin - @threaded for i in eachindex(integrator.u) - integrator.u_tmp1[i] += delta_stage * integrator.u[i] - integrator.u[i] = (gamma1_stage * integrator.u[i] + - gamma2_stage * integrator.u_tmp1[i] + - gamma3_stage * integrator.u_tmp2[i] + - beta_stage_dt * integrator.du[i]) - end + # one time step + integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1)) + integrator.u_tmp2 .= integrator.u + for stage in eachindex(alg.c) + t_stage = integrator.t + integrator.dt * alg.c[stage] + prob.f(integrator.du, integrator.u, prob.p, t_stage) + + delta_stage = alg.delta[stage] + gamma1_stage = alg.gamma1[stage] + gamma2_stage = alg.gamma2[stage] + gamma3_stage = alg.gamma3[stage] + beta_stage_dt = alg.beta[stage] * integrator.dt + @trixi_timeit timer() "Runge-Kutta step" begin + @threaded for i in eachindex(integrator.u) + integrator.u_tmp1[i] += delta_stage * integrator.u[i] + integrator.u[i] = (gamma1_stage * integrator.u[i] + + gamma2_stage * integrator.u_tmp1[i] + + gamma3_stage * integrator.u_tmp2[i] + + beta_stage_dt * integrator.du[i]) end end - integrator.iter += 1 - integrator.t += integrator.dt - - # handle callbacks - if callbacks isa CallbackSet - foreach(callbacks.discrete_callbacks) do cb - if cb.condition(integrator.u, integrator.t, integrator) - cb.affect!(integrator) - end - return nothing + end + integrator.iter += 1 + integrator.t += integrator.dt + + # handle callbacks + if callbacks isa CallbackSet + foreach(callbacks.discrete_callbacks) do cb + if cb.condition(integrator.u, integrator.t, integrator) + cb.affect!(integrator) end - end - - # respect maximum number of iterations - if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep - @warn "Interrupted. Larger maxiters is needed." - terminate!(integrator) + return nothing end end - return TimeIntegratorSolution((first(prob.tspan), integrator.t), - (prob.u0, integrator.u), - integrator.sol.prob) + # respect maximum number of iterations + if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep + @warn "Interrupted. Larger maxiters is needed." + terminate!(integrator) + end end # get a cache where the RHS can be stored