Skip to content

Commit

Permalink
fix: update GeneralDomain to the new ManifoldProjection API
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 27, 2024
1 parent 7b52768 commit 8087d11
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 81 deletions.
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[compat]
ADTypes = "1.9.0"
DiffEqCallbacks = "3"
Documenter = "1"
OrdinaryDiffEq = "6.88"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/projection.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ ManifoldProjection
Here we solve the harmonic oscillator:

```@example manifold
using OrdinaryDiffEq, DiffEqCallbacks, Plots
using OrdinaryDiffEq, DiffEqCallbacks, Plots, ADTypes
u0 = ones(2)
function f(du, u, p, t)
Expand All @@ -35,7 +35,7 @@ end
To build the callback, we just call

```@example manifold
cb = ManifoldProjection(g)
cb = ManifoldProjection(g; autodiff = AutoForwardDiff())
```

Using this callback, the Runge-Kutta method `Vern7` conserves energy. Note that the
Expand Down
132 changes: 78 additions & 54 deletions src/domain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,35 @@

abstract type AbstractDomainAffect{T, S, uType} end

(f::AbstractDomainAffect)(integrator) = affect!(integrator, f)

struct PositiveDomainAffect{T, S, uType} <: AbstractDomainAffect{T, S, uType}
abstol::T
scalefactor::S
u::uType
end

struct GeneralDomainAffect{autonomous, F, T, S, uType} <: AbstractDomainAffect{T, S, uType}
struct GeneralDomainAffect{F <: AbstractNonAutonomousFunction, T, S, uType, A} <:
AbstractDomainAffect{T, S, uType}
g::F
abstol::T
scalefactor::S
u::uType
resid::uType
autonomous::A
end

function GeneralDomainAffect{autonomous}(g::F, abstol::T, scalefactor::S, u::uType,
resid::uType) where {autonomous, F, T, S, uType
}
new{autonomous, F, T, S, uType}(g, abstol, scalefactor, u, resid)
function initialize_general_domain_affect(cb, u, t, integrator)
return initialize_general_domain_affect(cb.affect!, u, t, integrator)
end
function initialize_general_domain_affect(affect!::GeneralDomainAffect, u, t, integrator)
if affect!.autonomous === nothing
autonomous = maximum(SciMLBase.numargs(affect!.g.f)) ==
2 + SciMLBase.isinplace(integrator.f)
affect!.g.autonomous = autonomous
end
end

# definitions of callback functions

# Workaround since it is not possible to add methods to an abstract type:
# https://github.com/JuliaLang/julia/issues/14919
(f::PositiveDomainAffect)(integrator) = affect!(integrator, f)
(f::GeneralDomainAffect)(integrator) = affect!(integrator, f)

# general method definitions for domain callbacks

"""
Expand All @@ -41,6 +43,8 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
throw(ArgumentError("domain callback can only be applied to adaptive algorithms"))
end

iip = Val(SciMLBase.isinplace(integrator.f))

# define array of next time step, absolute tolerance, and scale factor
if uType <: Nothing
if integrator.u isa Union{Number, StaticArraysCore.SArray}
Expand All @@ -55,7 +59,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
scalefactor = S <: Nothing ? 1 // 2 : f.scalefactor

# setup callback and save additional arguments for checking next time step
args = setup(f, integrator)
args = setup(f, integrator, iip)

# obtain proposed next time step
dt = get_proposed_dt(integrator)
Expand All @@ -80,7 +84,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S,
end

# check whether time step is accepted
isaccepted(u, p, t, abstol, f, args...) && break
isaccepted(u, p, t, abstol, f, iip, args...) && break

# reduce time step
dtcache = dt
Expand Down Expand Up @@ -120,20 +124,20 @@ was modified.
modify_u!(integrator, ::AbstractDomainAffect) = false

"""
setup(f::AbstractDomainAffect, integrator)
setup(f::AbstractDomainAffect, integrator, ::Val{iip}) where {iip}
Setup callback `f` and return an arbitrary tuple whose elements are used as additional
arguments in checking whether time step is accepted.
"""
setup(::AbstractDomainAffect, integrator) = ()
setup(::AbstractDomainAffect, integrator, ::Val{iip}) where {iip} = ()

"""
isaccepted(u, abstol, f::AbstractDomainAffect, args...)
Return whether `u` is an acceptable state vector at the next time point given absolute
tolerance `abstol`, callback `f`, and other optional arguments.
"""
isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, args...) = true
isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, ::Val{iip}, args...) where {iip} = true

# specific method definitions for positive domain callback

Expand Down Expand Up @@ -175,27 +179,30 @@ function _set_neg_zero!(integrator, u::StaticArraysCore.SArray)
end

# state vector is accepted if its entries are greater than -abstol
isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect) = all(ui -> ui > -abstol, u)
function isaccepted(u, p, t, abstol, ::PositiveDomainAffect)
function isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect, ::Val{iip}) where {iip}
return all(ui -> ui > -abstol, u)
end
function isaccepted(u, p, t, abstol, ::PositiveDomainAffect, ::Val{iip}) where {iip}
length(u) == length(abstol) ||
throw(DimensionMismatch("numbers of states and tolerances do not match"))
all(ui > -tol for (ui, tol) in zip(u, abstol))
return all(ui > -tol for (ui, tol) in zip(u, abstol))
end

# specific method definitions for general domain callback

# create array of residuals
function setup(f::GeneralDomainAffect, integrator)
f.resid isa Nothing ? (similar(integrator.u),) : (f.resid,)
setup(f::GeneralDomainAffect, integrator, ::Val{false}) = (nothing,)
function setup(f::GeneralDomainAffect, integrator, ::Val{true})
return f.resid === nothing ? (similar(integrator.u),) : (f.resid,)
end

function isaccepted(u, p, t, abstol, f::GeneralDomainAffect{autonomous, F, T, S, uType},
resid) where {autonomous, F, T, S, uType}
function isaccepted(u, p, t, abstol, f::GeneralDomainAffect, ::Val{iip}, resid) where {iip}
# calculate residuals
if autonomous
f.g.t = t
if iip
f.g(resid, u, p)
else
f.g(resid, u, p, t)
resid = f.g(u, p)
end

# accept time step if residuals are smaller than the tolerance
Expand All @@ -214,26 +221,32 @@ end
"""
GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
autonomous = nothing, domain_jacobian = nothing,
nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...)
A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of
a `PositiveDomain` callback to arbitrary domains. Domains are specified by
in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` that calculate residuals of a
state vector `u` at time `t` relative to that domain, with `p` the parameters of the
corresponding integrator. As for `PositiveDomain`, steps are accepted if residuals
of the extrapolated values at the next time step are below
a certain tolerance. Moreover, this callback is automatically coupled with a
`ManifoldProjection` that keeps all calculated state vectors close to the desired
domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
`ManifoldProjection` cannot guarantee that all state vectors of the solution are
actually inside the domain. Thus, a `PositiveDomain` callback should generally be
preferred.
a `PositiveDomain` callback to arbitrary domains.
Domains are specified by
- in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` if the corresponding
ODEProblem is an inplace problem, or
- out-of-place functions `g(u, p)` or `g(u, p, t)` if the corresponding ODEProblem is
an out-of-place problem.
The function calculates residuals of a state vector `u` at time `t` relative to that domain,
with `p` the parameters of the corresponding integrator.
As for `PositiveDomain`, steps are accepted if residuals of the extrapolated values at the
next time step are below a certain tolerance. Moreover, this callback is automatically
coupled with a `ManifoldProjection` that keeps all calculated state vectors close to the
desired domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a
`ManifoldProjection` cannot guarantee that all state vectors of the solution are actually
inside the domain. Thus, a `PositiveDomain` callback should generally be preferred.
## Arguments
- `g`: the implicit definition of the domain as a function `g(resid, u, p)` or
`g(resid, u, p, t)` which is zero when the value is in the domain.
- `g`: the implicit definition of the domain as a function as described above which is
zero when the value is in the domain.
- `u`: A prototype of the state vector of the integrator. A copy of it is saved and
extrapolated values are written to it. If it is not specified,
every application of the callback allocates a new copy of the state vector.
Expand All @@ -248,9 +261,13 @@ preferred.
specified, time steps are halved.
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`.
If it is not specified, it is determined automatically.
- `kwargs`: All other keyword arguments are passed to `ManifoldProjection`.
- `kwargs`: All other keyword arguments are passed to [`ManifoldProjection`](@ref).
- `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in
`ManifoldProjection`. The default is `(; abstol = 10 * eps())`.
- `domain_jacobian`: The Jacobian of the domain (wrt the state). This has the same
signature as `g` and the first argument is the Jacobian if inplace. This corresponds to
the `manifold_jacobian` argument of [`ManifoldProjection`](@ref). Note that passing
a `manifold_jacobian` is not supported for `GeneralDomain` and results in an error.
## References
Expand All @@ -260,20 +277,27 @@ Non-negative solutions of ODEs. Applied Mathematics and Computation 170
"""
function GeneralDomain(
g, u = nothing; save = true, abstol = nothing, scalefactor = nothing,
autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (;
abstol = 10 * eps()), kwargs...)
_autonomous = SciMLBase._unwrap_val(autonomous)
if u isa Nothing
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing)
autonomous = nothing, domain_jacobian = nothing, manifold_jacobian = missing,
nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...)
if manifold_jacobian !== missing
throw(ArgumentError("`manifold_jacobian` is not supported for `GeneralDomain`. \
Use `domain_jacobian` instead."))
end
manifold_projection = ManifoldProjection(
g; save = false, autonomous, manifold_jacobian = domain_jacobian,
kwargs..., nlsolve_kwargs...)
domain = wrap_autonomous_function(autonomous, g)
domain_jacobian = wrap_autonomous_function(autonomous, domain_jacobian)
affect! = if u === nothing
GeneralDomainAffect(domain, abstol, scalefactor, nothing, nothing, autonomous)
else
affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u),
deepcopy(u))
GeneralDomainAffect(
domain, abstol, scalefactor, deepcopy(u), deepcopy(u), autonomous)
end
condition = (u, t, integrator) -> true
CallbackSet(
ManifoldProjection(
g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...),
DiscreteCallback(condition, affect!; save_positions = (false, save)))
domain_cb = DiscreteCallback(
Returns(true), affect!; initialize = initialize_general_domain_affect,
save_positions = (false, save))
return CallbackSet(manifold_projection, domain_cb)
end

@doc doc"""
Expand Down
41 changes: 19 additions & 22 deletions src/manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ properties.
would work in most cases (See [1] for details). Alternatively, a nonlinear solver as
defined in the
[NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/)
can be specified.
can be specified. Additionally if NonlinearSolve.jl is loaded and `nothing` is specified
a polyalgorithm is used.
- `save`: Whether to do the standard saving (applied after the callback)
- `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or
`g(u, p)`. Specify it as `Val(::Bool)` to disable runtime branching. If `nothing`,
Expand Down Expand Up @@ -88,25 +89,8 @@ end

function ManifoldProjection(
manifold, autodiff, manifold_jacobian, nlsolve, kwargs, autonomous)
if autonomous isa Val{true} || autonomous isa Val{false}
wrapped_manifold = TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(
manifold, nothing)
wrapped_manifold_jacobian = if manifold_jacobian === nothing
nothing
else
TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(
manifold_jacobian, nothing)
end
autonomous = SciMLBase._unwrap_val(autonomous)
else
_autonomous = autonomous === nothing ? false : autonomous
wrapped_manifold = UntypedNonAutonomousFunction(_autonomous, manifold, nothing)
wrapped_manifold_jacobian = if manifold_jacobian === nothing
nothing
else
UntypedNonAutonomousFunction(_autonomous, manifold_jacobian, nothing)
end
end
wrapped_manifold = wrap_autonomous_function(autonomous, manifold)
wrapped_manifold_jacobian = wrap_autonomous_function(autonomous, manifold_jacobian)
return ManifoldProjection(wrapped_manifold, wrapped_manifold_jacobian,
autodiff, nothing, nlsolve, kwargs, autonomous)
end
Expand Down Expand Up @@ -158,7 +142,20 @@ end
export ManifoldProjection

# wrapper for non-autonomous functions
@concrete mutable struct TypedNonAutonomousFunction{autonomous}
function wrap_autonomous_function(autonomous::Union{Val{true}, Val{false}}, g)
g === nothing && return nothing
return TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(g, nothing)
end
function wrap_autonomous_function(autonomous::Union{Bool, Nothing}, g)
g === nothing && return nothing
autonomous = autonomous === nothing ? false : autonomous
return UntypedNonAutonomousFunction(autonomous, g, nothing)
end

abstract type AbstractNonAutonomousFunction end

@concrete mutable struct TypedNonAutonomousFunction{autonomous} <:
AbstractNonAutonomousFunction
f
t::Any
end
Expand All @@ -169,7 +166,7 @@ end
(f::TypedNonAutonomousFunction{false})(u, p) = f.f(u, p, f.t)
(f::TypedNonAutonomousFunction{true})(u, p) = f.f(u, p)

@concrete mutable struct UntypedNonAutonomousFunction
@concrete mutable struct UntypedNonAutonomousFunction <: AbstractNonAutonomousFunction
autonomous::Bool
f
t::Any
Expand Down
14 changes: 11 additions & 3 deletions test/domain_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqCallbacks, OrdinaryDiffEq, Test
using DiffEqCallbacks, OrdinaryDiffEq, Test, ADTypes, NonlinearSolve

# Non-negative ODE examples
#
Expand Down Expand Up @@ -39,7 +39,11 @@ naive_sol_absval = solve(prob_absval, BS3())
function g(resid, u, p)
resid[1] = u[1] < 0 ? -u[1] : 0
end
general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0]),
general_sol_absval = solve(
prob_absval, BS3();
callback = GeneralDomain(g, [1.0];
autodiff = AutoForwardDiff(),
nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())),
save_everystep = false)
@test all(x -> x[1] 0, general_sol_absval.u)
@test general_sol_absval.errors[:l∞] < 9.9e-5
Expand All @@ -49,7 +53,11 @@ general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0]
# test "non-autonomous" function
g_t(resid, u, p, t) = g(resid, u, p)

general_t_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g_t, [1.0]),
general_t_sol_absval = solve(
prob_absval, BS3();
callback = GeneralDomain(g_t, [1.0];
autodiff = AutoForwardDiff(),
nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())),
save_everystep = false)
@test general_sol_absval.t general_t_sol_absval.t
@test general_sol_absval.u general_t_sol_absval.u
Expand Down

0 comments on commit 8087d11

Please sign in to comment.