Skip to content

Commit

Permalink
Revive SSPKnoth
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed May 24, 2024
1 parent 759efc8 commit 2ec7589
Show file tree
Hide file tree
Showing 8 changed files with 704 additions and 168 deletions.
274 changes: 142 additions & 132 deletions docs/Manifest.toml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pages = [
"Algorithm Formulations" => [
"ODE Solvers" => "algorithm_formulations/ode_solvers.md",
"Newtons Method" => "algorithm_formulations/newtons_method.md",
"Rosenbrock Method" => "algorithm_formulations/rosenbrock.md",
"Old LSRK Formulations" => "algorithm_formulations/lsrk.md",
"Old MRRK Formulations" => "algorithm_formulations/mrrk.md",
],
Expand Down
1 change: 1 addition & 0 deletions docs/src/plotting_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ imex_convergence_orders(::ARK548L2SA2) = (5, 5, 5)
imex_convergence_orders(::SSP22Heuns) = (2, 2, 2)
imex_convergence_orders(::SSP33ShuOsher) = (3, 3, 3)
imex_convergence_orders(::RK4) = (4, 4, 4)
imex_convergence_orders(::SSPKnoth) = (2, 2, 2)

# Compute a confidence interval for the convergence order, returning the
# estimated convergence order and its uncertainty.
Expand Down
1 change: 1 addition & 0 deletions ext/benchmark_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function get_trial(f, args, name; device, with_cu_prof = :bprofile, trace = fals
end

get_W(i::CTS.DistributedODEIntegrator) = i.cache.newtons_method_cache.j
get_W(i::CTS.RosenbrockAlgorithm) = i.cache.newtons_method_cache.j
get_W(i) = i.cache.W
f_args(i, f::CTS.ForwardEulerODEFunction) = (copy(i.u), i.u, i.p, i.t, i.dt)
f_args(i, f) = (similar(i.u), i.u, i.p, i.t)
Expand Down
183 changes: 148 additions & 35 deletions src/solvers/rosenbrock.jl
Original file line number Diff line number Diff line change
@@ -1,81 +1,198 @@
export SSPKnoth
using StaticArrays
import DiffEqBase
import LinearAlgebra: ldiv!

abstract type RosenbrockAlgorithm <: DistributedODEAlgorithm end
abstract type RosenbrockAlgorithmName <: AbstractAlgorithmName end

"""
RosenbrockTableau{N, RT, N²}
Contains everything that defines a Rosenbrock-type method.
- N: number of stages,
- N²: number of stages squared,
- RT: real type (Float32, Float64, ...)
Refer to the documentation for the precise meaning of the symbols below.
"""
struct RosenbrockTableau{N, RT, N²}
"""A = α Γ⁻¹"""
A::SMatrix{N, N, RT, N²}
"""Tableau used for the time-dependent part"""
α::SMatrix{N, N, RT, N²}
"""Stepping matrix"""
C::SMatrix{N, N, RT, N²}
"""Substage contribution matrix"""
Γ::SMatrix{N, N, RT, N²}
"""m = b Γ⁻¹, used to compute the increments k"""
m::SMatrix{N, 1, RT, N}
end

struct RosenbrockCache{Nstages, RT, N², A}
tableau::RosenbrockTableau{Nstages, RT, N²}
"""
RosenbrockAlgorithm(tableau)
Constructs a Rosenbrock algorithm for solving ODEs.
"""
struct RosenbrockAlgorithm{T <: RosenbrockTableau} <: ClimaTimeSteppers.DistributedODEAlgorithm
tableau::T
end

"""
RosenbrockCache{N, A, WT}
Contains everything that is needed to run a Rosenbrock-type method.
- Nstages: number of stages,
- A: type of the evolved state (e.g., a ClimaCore.FieldVector),
- WT: type of the Jacobian (Wfact)
"""
struct RosenbrockCache{Nstages, A, WT}
"""Preallocated space for the state"""
U::A

"""Preallocated space for the tendency"""
fU::A

"""Preallocated space for the explicit contribution to the tendency"""
fU_exp::A

"""Preallocated space for the limited contribution to the tendency"""
fU_lim::A

"""Contributions to the state for each stage"""
k::NTuple{Nstages, A}
W::Any
linsolve!::Any

"""Preallocated space for the Wfact, dtγJ - 𝕀, or Wfact_t, 𝕀/dtγ - J, with J the Jacobian of the implicit tendency"""
W::WT

"""Preallocated space for the explicit time derivative of the tendency"""
∂Y∂t::A
end

function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::RosenbrockAlgorithm; kwargs...)

tab = tableau(alg, eltype(prob.u0))
Nstages = length(tab.m)
Nstages = length(alg.tableau.m)
U = zero(prob.u0)
fU = zero(prob.u0)
fU_exp = zero(prob.u0)
fU_lim = zero(prob.u0)
∂Y∂t = zero(prob.u0)
k = ntuple(n -> similar(prob.u0), Nstages)
W = prob.f.jac_prototype
linsolve! = alg.linsolve(Val{:init}, W, prob.u0; kwargs...)

return RosenbrockCache(tab, U, fU, k, W, linsolve!)
W = prob.f.T_imp!.jac_prototype
return RosenbrockCache{Nstages, typeof(U), typeof(W)}(U, fU, fU_exp, fU_lim, k, W, ∂Y∂t)
end

"""
step_u!(int, cache::RosenbrockCache{Nstages})
Take one step with the Rosenbrock-method with the given `cache`.
function step_u!(int, cache::RosenbrockCache{Nstages, RT}) where {Nstages, RT}
(; m, Γ, A, C) = cache.tableau
Some choices are being made here. Most of these are empirically motivated and should be
revisited on different problems.
- We do not update dtγ across stages
- We do not update Wfact across stages
- We apply DSS to the sum of the explicit and implicit tendency at all the stages but the last
- We apply DSS to incremented state (ie, after the final stage is applied)
"""
function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
(; m, Γ, A, α, C) = int.alg.tableau
(; u, p, t, dt) = int
(; W, U, fU, k, linsolve!) = cache
f! = int.sol.prob.f
Wfact_t! = int.sol.prob.f.Wfact_t
(; W, U, fU, fU_exp, fU_lim, k, ∂Y∂t) = cache
T_imp! = int.sol.prob.f.T_imp!
T_exp_lim! = int.sol.prob.f.T_exp_T_lim!
tgrad! = int.sol.prob.f.T_imp!.tgrad

Wfact! = int.sol.prob.f.T_imp!.Wfact
Wfact_t! = int.sol.prob.f.T_imp!.Wfact_t
if !isnothing(Wfact!) && !isnothing(Wfact_t!)
error("Only one between Wfact and Wfact_t can be non-nothing")
end

(; post_explicit!, post_implicit!, dss!) = int.sol.prob.f

dtγ = dt * Γ[1, 1]

if isnothing(Wfact_t!)
# We have Wfact
Wfact!(W, u, p, dtγ, t)
else
# We have Wfact_t
Wfact_t!(W, u, p, dtγ, t)
end


# 1) compute jacobian factorization
γ = dt * Γ[1, 1]
Wfact_t!(W, u, p, γ, t)
for i in 1:Nstages
αi = sum(α[i, 1:(i - 1)])

U .= u
for j in 1:(i - 1)
U .+= A[i, j] .* k[j]
end
# TODO: there should be a time modification here (t + c * dt)
# if f does depend on time, would need to add tgrad term as well
f!(fU, U, p, t)

# NOTE: post_implicit! is a misnomer
post_implicit!(U, p, t)

if !isnothing(T_imp!)
T_imp!(fU, U, p, t)
end

if !isnothing(T_exp_lim!)
T_exp_lim!(fU_exp, fU_lim, U, p, t)
fU .+= fU_exp
fU .+= fU_lim
end

# We dss the tendency at every stage but the last. At the last stage, we
# dss the incremented state
(i != Nstages) && dss!(fU, p, t)

for j in 1:(i - 1)
fU .+= (C[i, j] / dt) .* k[j]
end
linsolve!(k[i], W, fU)

if !isnothing(tgrad!)
tgrad!(∂Y∂t, u, p, t)
fU .+= αi .* dt .* ∂Y∂t
end

if isnothing(Wfact_t!)
fU .*= -dtγ
end

if W isa Matrix
ldiv!(k[i], lu(W), fU)
else
ldiv!(k[i], W, fU)
end
end

for i in 1:Nstages
u .+= m[i] .* k[i]
end
end

struct SSPKnoth{L} <: RosenbrockAlgorithm
linsolve::L
dss!(u, p, t)
return nothing
end
SSPKnoth(; linsolve) = SSPKnoth(linsolve)

"""
SSPKnoth
`SSPKnoth` is a second-order Rosenbrock method.
We do not know where the coefficients come from. They are the same as in `CGDycore.jl`.
"""
struct SSPKnoth end

struct SSPKnoth <: RosenbrockAlgorithmName end

function tableau(::SSPKnoth, RT)
# ROS.transformed=true;
N = 3
= N * N
α = @SMatrix RT[
0 0 0
1 0 0
1/4 1/4 0
]
# ROS.d=ROS.alpha*ones(ROS.nStage,1);
b = @SMatrix RT[1 / 6 1 / 6 2 / 3]
Γ = @SMatrix RT[
1 0 0
Expand All @@ -85,9 +202,5 @@ function tableau(::SSPKnoth, RT)
A = α / Γ
C = -inv(Γ)
m = b / Γ
return RosenbrockTableau{N, RT, N²}(A, C, Γ, m)
# ROS.SSP.alpha=[1 0 0
# 3/4 1/4 0
# 1/3 0 2/3];

return RosenbrockTableau{N, RT, N²}(A, α, C, Γ, m)
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
ClimaCore = "d414da3d-4745-48bb-8d80-42e94e092884"
ClimaCorePlots = "cf7c7e5a-b407-4c48-9047-11a94a308626"
ClimaTimeSteppers = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down
15 changes: 14 additions & 1 deletion test/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using ClimaCore
using ClimaComms
import ClimaCore: Domains, Geometry, Meshes, Topologies, Spaces, Fields, Operators, Limiters

@static isdefined(ClimaComms, :device_type) && ClimaComms.@import_required_backends
@static pkgversion(ClimaComms) >= v"0.6" && ClimaComms.@import_required_backends

import Krylov # Trigger ClimaCore/ext/KrylovExt

Expand Down Expand Up @@ -380,6 +380,19 @@ function ark_analytic_nonlin_test_cts(::Type{FT}) where {FT}
)
end

function sspknoth_analytic_nonlin_test_cts(::Type{FT}) where {FT}
ClimaIntegratorTestCase(;
test_name = "sspknoth_analytic_nonlin",
linear_implicit = false,
t_end = FT(10),
Y₀ = FT[0],
analytic_sol = (t) -> [log(t^2 / 2 + t + 1)],
tendency! = (Yₜ, Y, _, t) -> Yₜ .= (t + 1) .* exp.(.-Y),
Wfact! = (W, Y, _, Δt, t) -> W .= (-Δt * (t + 1) .* exp.(.-Y) .- 1),
tgrad! = (∂Y∂t, Y, _, t) -> ∂Y∂t .= exp.(.-Y),
)
end

# From Section 5.1 of "Example Programs for ARKode v4.4.0" by D. R. Reynolds
function ark_analytic_sys_test_cts(::Type{FT}) where {FT}
λ = FT(-100) # increase magnitude for more stiffness
Expand Down
Loading

0 comments on commit 2ec7589

Please sign in to comment.