From 8087d11b112b32fc8a904f949d57e006c5e61df3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 27 Sep 2024 13:10:45 -0400 Subject: [PATCH] fix: update GeneralDomain to the new ManifoldProjection API --- docs/Project.toml | 2 + docs/src/projection.md | 4 +- src/domain.jl | 132 ++++++++++++++++++++++++----------------- src/manifold.jl | 41 ++++++------- test/domain_tests.jl | 14 ++++- 5 files changed, 112 insertions(+), 81 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 759ff0db..70dbd1e3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -6,6 +7,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [compat] +ADTypes = "1.9.0" DiffEqCallbacks = "3" Documenter = "1" OrdinaryDiffEq = "6.88" diff --git a/docs/src/projection.md b/docs/src/projection.md index 31b08bc2..7c0d2f12 100644 --- a/docs/src/projection.md +++ b/docs/src/projection.md @@ -12,7 +12,7 @@ ManifoldProjection Here we solve the harmonic oscillator: ```@example manifold -using OrdinaryDiffEq, DiffEqCallbacks, Plots +using OrdinaryDiffEq, DiffEqCallbacks, Plots, ADTypes u0 = ones(2) function f(du, u, p, t) @@ -35,7 +35,7 @@ end To build the callback, we just call ```@example manifold -cb = ManifoldProjection(g) +cb = ManifoldProjection(g; autodiff = AutoForwardDiff()) ``` Using this callback, the Runge-Kutta method `Vern7` conserves energy. Note that the diff --git a/src/domain.jl b/src/domain.jl index d9879886..581a18f8 100644 --- a/src/domain.jl +++ b/src/domain.jl @@ -2,33 +2,35 @@ abstract type AbstractDomainAffect{T, S, uType} end +(f::AbstractDomainAffect)(integrator) = affect!(integrator, f) + struct PositiveDomainAffect{T, S, uType} <: AbstractDomainAffect{T, S, uType} abstol::T scalefactor::S u::uType end -struct GeneralDomainAffect{autonomous, F, T, S, uType} <: AbstractDomainAffect{T, S, uType} +struct GeneralDomainAffect{F <: AbstractNonAutonomousFunction, T, S, uType, A} <: + AbstractDomainAffect{T, S, uType} g::F abstol::T scalefactor::S u::uType resid::uType + autonomous::A +end - function GeneralDomainAffect{autonomous}(g::F, abstol::T, scalefactor::S, u::uType, - resid::uType) where {autonomous, F, T, S, uType - } - new{autonomous, F, T, S, uType}(g, abstol, scalefactor, u, resid) +function initialize_general_domain_affect(cb, u, t, integrator) + return initialize_general_domain_affect(cb.affect!, u, t, integrator) +end +function initialize_general_domain_affect(affect!::GeneralDomainAffect, u, t, integrator) + if affect!.autonomous === nothing + autonomous = maximum(SciMLBase.numargs(affect!.g.f)) == + 2 + SciMLBase.isinplace(integrator.f) + affect!.g.autonomous = autonomous end end -# definitions of callback functions - -# Workaround since it is not possible to add methods to an abstract type: -# https://github.com/JuliaLang/julia/issues/14919 -(f::PositiveDomainAffect)(integrator) = affect!(integrator, f) -(f::GeneralDomainAffect)(integrator) = affect!(integrator, f) - # general method definitions for domain callbacks """ @@ -41,6 +43,8 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S, throw(ArgumentError("domain callback can only be applied to adaptive algorithms")) end + iip = Val(SciMLBase.isinplace(integrator.f)) + # define array of next time step, absolute tolerance, and scale factor if uType <: Nothing if integrator.u isa Union{Number, StaticArraysCore.SArray} @@ -55,7 +59,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S, scalefactor = S <: Nothing ? 1 // 2 : f.scalefactor # setup callback and save additional arguments for checking next time step - args = setup(f, integrator) + args = setup(f, integrator, iip) # obtain proposed next time step dt = get_proposed_dt(integrator) @@ -80,7 +84,7 @@ function affect!(integrator, f::AbstractDomainAffect{T, S, uType}) where {T, S, end # check whether time step is accepted - isaccepted(u, p, t, abstol, f, args...) && break + isaccepted(u, p, t, abstol, f, iip, args...) && break # reduce time step dtcache = dt @@ -120,12 +124,12 @@ was modified. modify_u!(integrator, ::AbstractDomainAffect) = false """ - setup(f::AbstractDomainAffect, integrator) + setup(f::AbstractDomainAffect, integrator, ::Val{iip}) where {iip} Setup callback `f` and return an arbitrary tuple whose elements are used as additional arguments in checking whether time step is accepted. """ -setup(::AbstractDomainAffect, integrator) = () +setup(::AbstractDomainAffect, integrator, ::Val{iip}) where {iip} = () """ isaccepted(u, abstol, f::AbstractDomainAffect, args...) @@ -133,7 +137,7 @@ setup(::AbstractDomainAffect, integrator) = () Return whether `u` is an acceptable state vector at the next time point given absolute tolerance `abstol`, callback `f`, and other optional arguments. """ -isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, args...) = true +isaccepted(u, p, t, tolerance, ::AbstractDomainAffect, ::Val{iip}, args...) where {iip} = true # specific method definitions for positive domain callback @@ -175,27 +179,30 @@ function _set_neg_zero!(integrator, u::StaticArraysCore.SArray) end # state vector is accepted if its entries are greater than -abstol -isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect) = all(ui -> ui > -abstol, u) -function isaccepted(u, p, t, abstol, ::PositiveDomainAffect) +function isaccepted(u, p, t, abstol::Number, ::PositiveDomainAffect, ::Val{iip}) where {iip} + return all(ui -> ui > -abstol, u) +end +function isaccepted(u, p, t, abstol, ::PositiveDomainAffect, ::Val{iip}) where {iip} length(u) == length(abstol) || throw(DimensionMismatch("numbers of states and tolerances do not match")) - all(ui > -tol for (ui, tol) in zip(u, abstol)) + return all(ui > -tol for (ui, tol) in zip(u, abstol)) end # specific method definitions for general domain callback # create array of residuals -function setup(f::GeneralDomainAffect, integrator) - f.resid isa Nothing ? (similar(integrator.u),) : (f.resid,) +setup(f::GeneralDomainAffect, integrator, ::Val{false}) = (nothing,) +function setup(f::GeneralDomainAffect, integrator, ::Val{true}) + return f.resid === nothing ? (similar(integrator.u),) : (f.resid,) end -function isaccepted(u, p, t, abstol, f::GeneralDomainAffect{autonomous, F, T, S, uType}, - resid) where {autonomous, F, T, S, uType} +function isaccepted(u, p, t, abstol, f::GeneralDomainAffect, ::Val{iip}, resid) where {iip} # calculate residuals - if autonomous + f.g.t = t + if iip f.g(resid, u, p) else - f.g(resid, u, p, t) + resid = f.g(u, p) end # accept time step if residuals are smaller than the tolerance @@ -214,26 +221,32 @@ end """ GeneralDomain( g, u = nothing; save = true, abstol = nothing, scalefactor = nothing, - autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (; - abstol = 10 * eps()), kwargs...) + autonomous = nothing, domain_jacobian = nothing, + nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...) A `GeneralDomain` callback in DiffEqCallbacks.jl generalizes the concept of -a `PositiveDomain` callback to arbitrary domains. Domains are specified by -in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` that calculate residuals of a -state vector `u` at time `t` relative to that domain, with `p` the parameters of the -corresponding integrator. As for `PositiveDomain`, steps are accepted if residuals -of the extrapolated values at the next time step are below -a certain tolerance. Moreover, this callback is automatically coupled with a -`ManifoldProjection` that keeps all calculated state vectors close to the desired -domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a -`ManifoldProjection` cannot guarantee that all state vectors of the solution are -actually inside the domain. Thus, a `PositiveDomain` callback should generally be -preferred. +a `PositiveDomain` callback to arbitrary domains. + +Domains are specified by + - in-place functions `g(resid, u, p)` or `g(resid, u, p, t)` if the corresponding + ODEProblem is an inplace problem, or + - out-of-place functions `g(u, p)` or `g(u, p, t)` if the corresponding ODEProblem is + an out-of-place problem. + +The function calculates residuals of a state vector `u` at time `t` relative to that domain, +with `p` the parameters of the corresponding integrator. + +As for `PositiveDomain`, steps are accepted if residuals of the extrapolated values at the +next time step are below a certain tolerance. Moreover, this callback is automatically +coupled with a `ManifoldProjection` that keeps all calculated state vectors close to the +desired domain, but in contrast to a `PositiveDomain` callback the nonlinear solver in a +`ManifoldProjection` cannot guarantee that all state vectors of the solution are actually +inside the domain. Thus, a `PositiveDomain` callback should generally be preferred. ## Arguments - - `g`: the implicit definition of the domain as a function `g(resid, u, p)` or - `g(resid, u, p, t)` which is zero when the value is in the domain. + - `g`: the implicit definition of the domain as a function as described above which is + zero when the value is in the domain. - `u`: A prototype of the state vector of the integrator. A copy of it is saved and extrapolated values are written to it. If it is not specified, every application of the callback allocates a new copy of the state vector. @@ -248,9 +261,13 @@ preferred. specified, time steps are halved. - `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)`. If it is not specified, it is determined automatically. - - `kwargs`: All other keyword arguments are passed to `ManifoldProjection`. + - `kwargs`: All other keyword arguments are passed to [`ManifoldProjection`](@ref). - `nlsolve_kwargs`: All keyword arguments are passed to the nonlinear solver in `ManifoldProjection`. The default is `(; abstol = 10 * eps())`. + - `domain_jacobian`: The Jacobian of the domain (wrt the state). This has the same + signature as `g` and the first argument is the Jacobian if inplace. This corresponds to + the `manifold_jacobian` argument of [`ManifoldProjection`](@ref). Note that passing + a `manifold_jacobian` is not supported for `GeneralDomain` and results in an error. ## References @@ -260,20 +277,27 @@ Non-negative solutions of ODEs. Applied Mathematics and Computation 170 """ function GeneralDomain( g, u = nothing; save = true, abstol = nothing, scalefactor = nothing, - autonomous = maximum(SciMLBase.numargs(g)) == 3, nlsolve_kwargs = (; - abstol = 10 * eps()), kwargs...) - _autonomous = SciMLBase._unwrap_val(autonomous) - if u isa Nothing - affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, nothing, nothing) + autonomous = nothing, domain_jacobian = nothing, manifold_jacobian = missing, + nlsolve_kwargs = (; abstol = 10 * eps()), kwargs...) + if manifold_jacobian !== missing + throw(ArgumentError("`manifold_jacobian` is not supported for `GeneralDomain`. \ + Use `domain_jacobian` instead.")) + end + manifold_projection = ManifoldProjection( + g; save = false, autonomous, manifold_jacobian = domain_jacobian, + kwargs..., nlsolve_kwargs...) + domain = wrap_autonomous_function(autonomous, g) + domain_jacobian = wrap_autonomous_function(autonomous, domain_jacobian) + affect! = if u === nothing + GeneralDomainAffect(domain, abstol, scalefactor, nothing, nothing, autonomous) else - affect! = GeneralDomainAffect{_autonomous}(g, abstol, scalefactor, deepcopy(u), - deepcopy(u)) + GeneralDomainAffect( + domain, abstol, scalefactor, deepcopy(u), deepcopy(u), autonomous) end - condition = (u, t, integrator) -> true - CallbackSet( - ManifoldProjection( - g; save = false, autonomous, isinplace = Val(true), kwargs..., nlsolve_kwargs...), - DiscreteCallback(condition, affect!; save_positions = (false, save))) + domain_cb = DiscreteCallback( + Returns(true), affect!; initialize = initialize_general_domain_affect, + save_positions = (false, save)) + return CallbackSet(manifold_projection, domain_cb) end @doc doc""" diff --git a/src/manifold.jl b/src/manifold.jl index 74eef9db..9d427d40 100644 --- a/src/manifold.jl +++ b/src/manifold.jl @@ -31,7 +31,8 @@ properties. would work in most cases (See [1] for details). Alternatively, a nonlinear solver as defined in the [NonlinearSolve.jl format](https://docs.sciml.ai/NonlinearSolve/stable/basics/solve/) - can be specified. + can be specified. Additionally if NonlinearSolve.jl is loaded and `nothing` is specified + a polyalgorithm is used. - `save`: Whether to do the standard saving (applied after the callback) - `autonomous`: Whether `g` is an autonomous function of the form `g(resid, u, p)` or `g(u, p)`. Specify it as `Val(::Bool)` to disable runtime branching. If `nothing`, @@ -88,25 +89,8 @@ end function ManifoldProjection( manifold, autodiff, manifold_jacobian, nlsolve, kwargs, autonomous) - if autonomous isa Val{true} || autonomous isa Val{false} - wrapped_manifold = TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}( - manifold, nothing) - wrapped_manifold_jacobian = if manifold_jacobian === nothing - nothing - else - TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}( - manifold_jacobian, nothing) - end - autonomous = SciMLBase._unwrap_val(autonomous) - else - _autonomous = autonomous === nothing ? false : autonomous - wrapped_manifold = UntypedNonAutonomousFunction(_autonomous, manifold, nothing) - wrapped_manifold_jacobian = if manifold_jacobian === nothing - nothing - else - UntypedNonAutonomousFunction(_autonomous, manifold_jacobian, nothing) - end - end + wrapped_manifold = wrap_autonomous_function(autonomous, manifold) + wrapped_manifold_jacobian = wrap_autonomous_function(autonomous, manifold_jacobian) return ManifoldProjection(wrapped_manifold, wrapped_manifold_jacobian, autodiff, nothing, nlsolve, kwargs, autonomous) end @@ -158,7 +142,20 @@ end export ManifoldProjection # wrapper for non-autonomous functions -@concrete mutable struct TypedNonAutonomousFunction{autonomous} +function wrap_autonomous_function(autonomous::Union{Val{true}, Val{false}}, g) + g === nothing && return nothing + return TypedNonAutonomousFunction{SciMLBase._unwrap_val(autonomous)}(g, nothing) +end +function wrap_autonomous_function(autonomous::Union{Bool, Nothing}, g) + g === nothing && return nothing + autonomous = autonomous === nothing ? false : autonomous + return UntypedNonAutonomousFunction(autonomous, g, nothing) +end + +abstract type AbstractNonAutonomousFunction end + +@concrete mutable struct TypedNonAutonomousFunction{autonomous} <: + AbstractNonAutonomousFunction f t::Any end @@ -169,7 +166,7 @@ end (f::TypedNonAutonomousFunction{false})(u, p) = f.f(u, p, f.t) (f::TypedNonAutonomousFunction{true})(u, p) = f.f(u, p) -@concrete mutable struct UntypedNonAutonomousFunction +@concrete mutable struct UntypedNonAutonomousFunction <: AbstractNonAutonomousFunction autonomous::Bool f t::Any diff --git a/test/domain_tests.jl b/test/domain_tests.jl index 9aaee626..0130b246 100644 --- a/test/domain_tests.jl +++ b/test/domain_tests.jl @@ -1,4 +1,4 @@ -using DiffEqCallbacks, OrdinaryDiffEq, Test +using DiffEqCallbacks, OrdinaryDiffEq, Test, ADTypes, NonlinearSolve # Non-negative ODE examples # @@ -39,7 +39,11 @@ naive_sol_absval = solve(prob_absval, BS3()) function g(resid, u, p) resid[1] = u[1] < 0 ? -u[1] : 0 end -general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0]), +general_sol_absval = solve( + prob_absval, BS3(); + callback = GeneralDomain(g, [1.0]; + autodiff = AutoForwardDiff(), + nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())), save_everystep = false) @test all(x -> x[1] ā‰„ 0, general_sol_absval.u) @test general_sol_absval.errors[:lāˆž] < 9.9e-5 @@ -49,7 +53,11 @@ general_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g, [1.0] # test "non-autonomous" function g_t(resid, u, p, t) = g(resid, u, p) -general_t_sol_absval = solve(prob_absval, BS3(); callback = GeneralDomain(g_t, [1.0]), +general_t_sol_absval = solve( + prob_absval, BS3(); + callback = GeneralDomain(g_t, [1.0]; + autodiff = AutoForwardDiff(), + nlsolve=NewtonRaphson(; autodiff = AutoForwardDiff())), save_everystep = false) @test general_sol_absval.t ā‰ˆ general_t_sol_absval.t @test general_sol_absval.u ā‰ˆ general_t_sol_absval.u