Skip to content

Commit

Permalink
Refactor time integrator 2N and 3Star more similar to OrdinaryDiffEq.…
Browse files Browse the repository at this point in the history
…jl integrators (#1975)

* refactor 2N and 3*

* Apply suggestions from code review

Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com>

* fix the logic of step!

* Apply suggestions from code review

Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com>

* fmt

* Revert "fmt"

This reverts commit edb92b0.

* fmt

* add name to AUTHORS.md

---------

Co-authored-by: Daniel Doehring <doehringd2@gmail.com>
Co-authored-by: Hendrik Ranocha <ranocha@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 19, 2024
1 parent 75d8c67 commit 5a642f2
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 100 deletions.
1 change: 1 addition & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ are listed in alphabetical order:
* Julia Odenthal
* Sigrun Ortleb
* Hendrik Ranocha
* Warisa Roongaraya
* Andrés M. Rueda-Ramírez
* Felipe Santillan
* Michael Schlottke-Lakemper
Expand Down
110 changes: 63 additions & 47 deletions src/time_integration/methods_2N.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ function Base.getproperty(integrator::SimpleIntegrator2N, field::Symbol)
return getfield(integrator, field)
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm2N}
function init(ode::ODEProblem, alg::SimpleAlgorithm2N;
dt, callback = nothing, kwargs...)
u = copy(ode.u0)
du = similar(u)
u_tmp = similar(u)
Expand All @@ -129,67 +128,84 @@ function solve(ode::ODEProblem, alg::T;
error("unsupported")
end

return integrator
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::SimpleAlgorithm2N;
dt, callback = nothing, kwargs...)
integrator = init(ode, alg, dt = dt, callback = callback; kwargs...)

# Start actual solve
solve!(integrator)
end

function solve!(integrator::SimpleIntegrator2N)
@unpack prob = integrator.sol

integrator.finalstep = false

@trixi_timeit timer() "main loop" while !integrator.finalstep
step!(integrator)
end # "main loop" timer

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
end

function step!(integrator::SimpleIntegrator2N)
@unpack prob = integrator.sol
@unpack alg = integrator
t_end = last(prob.tspan)
callbacks = integrator.opts.callback

integrator.finalstep = false
@trixi_timeit timer() "main loop" while !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end
@assert !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end

# one time step
integrator.u_tmp .= 0
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
integrator.f(integrator.du, integrator.u, prob.p, t_stage)

a_stage = alg.a[stage]
b_stage_dt = alg.b[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.du[i] -
integrator.u_tmp[i] * a_stage
integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
end
# one time step
integrator.u_tmp .= 0
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
integrator.f(integrator.du, integrator.u, prob.p, t_stage)

a_stage = alg.a[stage]
b_stage_dt = alg.b[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp[i] = integrator.du[i] -
integrator.u_tmp[i] * a_stage
integrator.u[i] += integrator.u_tmp[i] * b_stage_dt
end
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
end

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
return nothing
end
end

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
end

# get a cache where the RHS can be stored
Expand Down
122 changes: 69 additions & 53 deletions src/time_integration/methods_3Sstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,8 @@ function Base.getproperty(integrator::SimpleIntegrator3Sstar, field::Symbol)
return getfield(integrator, field)
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::T;
dt, callback = nothing, kwargs...) where {T <: SimpleAlgorithm3Sstar}
function init(ode::ODEProblem, alg::SimpleAlgorithm3Sstar;
dt, callback = nothing, kwargs...)
u = copy(ode.u0)
du = similar(u)
u_tmp1 = similar(u)
Expand All @@ -199,73 +198,90 @@ function solve(ode::ODEProblem, alg::T;
error("unsupported")
end

return integrator
end

# Fakes `solve`: https://diffeq.sciml.ai/v6.8/basics/overview/#Solving-the-Problems-1
function solve(ode::ODEProblem, alg::SimpleAlgorithm3Sstar;
dt, callback = nothing, kwargs...)
integrator = init(ode, alg, dt = dt, callback = callback; kwargs...)

# Start actual solve
solve!(integrator)
end

function solve!(integrator::SimpleIntegrator3Sstar)
@unpack prob = integrator.sol

integrator.finalstep = false

@trixi_timeit timer() "main loop" while !integrator.finalstep
step!(integrator)
end # "main loop" timer

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
end

function step!(integrator::SimpleIntegrator3Sstar)
@unpack prob = integrator.sol
@unpack alg = integrator
t_end = last(prob.tspan)
callbacks = integrator.opts.callback

integrator.finalstep = false
@trixi_timeit timer() "main loop" while !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end
@assert !integrator.finalstep
if isnan(integrator.dt)
error("time step size `dt` is NaN")
end

# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end
# if the next iteration would push the simulation beyond the end time, set dt accordingly
if integrator.t + integrator.dt > t_end ||
isapprox(integrator.t + integrator.dt, t_end)
integrator.dt = t_end - integrator.t
terminate!(integrator)
end

# one time step
integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1))
integrator.u_tmp2 .= integrator.u
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
prob.f(integrator.du, integrator.u, prob.p, t_stage)

delta_stage = alg.delta[stage]
gamma1_stage = alg.gamma1[stage]
gamma2_stage = alg.gamma2[stage]
gamma3_stage = alg.gamma3[stage]
beta_stage_dt = alg.beta[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp1[i] += delta_stage * integrator.u[i]
integrator.u[i] = (gamma1_stage * integrator.u[i] +
gamma2_stage * integrator.u_tmp1[i] +
gamma3_stage * integrator.u_tmp2[i] +
beta_stage_dt * integrator.du[i])
end
# one time step
integrator.u_tmp1 .= zero(eltype(integrator.u_tmp1))
integrator.u_tmp2 .= integrator.u
for stage in eachindex(alg.c)
t_stage = integrator.t + integrator.dt * alg.c[stage]
prob.f(integrator.du, integrator.u, prob.p, t_stage)

delta_stage = alg.delta[stage]
gamma1_stage = alg.gamma1[stage]
gamma2_stage = alg.gamma2[stage]
gamma3_stage = alg.gamma3[stage]
beta_stage_dt = alg.beta[stage] * integrator.dt
@trixi_timeit timer() "Runge-Kutta step" begin
@threaded for i in eachindex(integrator.u)
integrator.u_tmp1[i] += delta_stage * integrator.u[i]
integrator.u[i] = (gamma1_stage * integrator.u[i] +
gamma2_stage * integrator.u_tmp1[i] +
gamma3_stage * integrator.u_tmp2[i] +
beta_stage_dt * integrator.du[i])
end
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
return nothing
end
integrator.iter += 1
integrator.t += integrator.dt

# handle callbacks
if callbacks isa CallbackSet
foreach(callbacks.discrete_callbacks) do cb
if cb.condition(integrator.u, integrator.t, integrator)
cb.affect!(integrator)
end
end

# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
return nothing
end
end

return TimeIntegratorSolution((first(prob.tspan), integrator.t),
(prob.u0, integrator.u),
integrator.sol.prob)
# respect maximum number of iterations
if integrator.iter >= integrator.opts.maxiters && !integrator.finalstep
@warn "Interrupted. Larger maxiters is needed."
terminate!(integrator)
end
end

# get a cache where the RHS can be stored
Expand Down

0 comments on commit 5a642f2

Please sign in to comment.