Skip to content

Commit

Permalink
feat: lagrangian multiplier based projection algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 26, 2024
1 parent d44a4b2 commit 7b52768
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 26 deletions.
125 changes: 105 additions & 20 deletions src/manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ properties.
`nlsolve` is not `missing`.
- `autodiff`: The autodifferentiation algorithm to use to compute the Jacobian if
`manifold_jacobian` is not specified. This must be specified if `manifold_jacobian` is
not specified and `nlsolve` is `missing`. If `nlsolve` is not `missing`, then
`autodiff` is ignored.
not specified.
- `manifold_jacobian`: The Jacobian of the manifold (wrt the state). This has the same
signature as `manifold` and the first argument is the Jacobian if inplace.
Expand Down Expand Up @@ -118,13 +117,7 @@ function (proj::ManifoldProjection)(integrator)
proj.manifold_jacobian !== nothing && (proj.manifold_jacobian.t = integrator.t)

SciMLBase.reinit!(proj.nlcache, integrator.u; integrator.p)

if proj.nlsolve === missing
_, u, retcode = SciMLBase.solve!(proj.nlcache)
else
sol = SciMLBase.solve!(proj.nlcache)
(; u, retcode) = sol
end
_, u, retcode = SciMLBase.solve!(proj.nlcache)

if !SciMLBase.successful_retcode(retcode)
SciMLBase.terminate!(integrator, retcode)
Expand All @@ -146,17 +139,17 @@ function initialize_manifold_projection(affect!::ManifoldProjection, u, t, integ
(affect!.manifold_jacobian.autonomous = autonomous)
end

affect!.manifold.t = t
affect!.manifold_jacobian !== nothing && (affect!.manifold_jacobian.t = t)

if affect!.nlsolve === missing
affect!.manifold.t = t
affect!.manifold_jacobian !== nothing && (affect!.manifold_jacobian.t = t)
cache = init_manifold_projection(
Val(SciMLBase.isinplace(integrator.f)), affect!.manifold, affect!.autodiff,
affect!.manifold_jacobian, u, integrator.p; affect!.kwargs...)
else
# nlfunc = NonlinearFunction{iip}(affect!.g; affect!.resid_prototype)
# nlprob = NonlinearProblem(nlfunc, u, integrator.p)
# affect!.nlcache = init(nlprob, affect!.nlsolve; affect!.kwargs...)
error("Not Implemented")
cache = init_manifold_projection_nonlinear_problem(
Val(SciMLBase.isinplace(integrator.f)), affect!.manifold, affect!.autodiff,
affect!.manifold_jacobian, u, integrator.p, affect!.nlsolve; affect!.kwargs...)
end
affect!.nlcache = cache
u_modified!(integrator, false)
Expand Down Expand Up @@ -187,6 +180,97 @@ function (f::UntypedNonAutonomousFunction)(res, u, p)
end
(f::UntypedNonAutonomousFunction)(u, p) = f.autonomous ? f.f(u, p) : f.f(u, p, f.t)

# This is solving the langrange multiplier formulation. This is more accurate but at the
# same time significantly more expensive.
@concrete mutable struct NonlinearSolveManifoldProjectionCache{iip}
manifold
p
λ
z
gu_cache
nlcache

first_call::Bool
J
manifold_jacobian
autodiff
di_extras
end

function SciMLBase.reinit!(
cache::NonlinearSolveManifoldProjectionCache{iip}, u; p = cache.p) where {iip}
if !cache.first_call || (cache.!== u || cache.p !== p)
compute_manifold_jacobian!(cache.J, cache.manifold_jacobian, cache.autodiff,
Val(iip), cache.manifold, cache.gu_cache, u, p, cache.di_extras)
end
cache.first_call = false
cache.= u
cache.p = p

cache.z[1:length(cache.λ)] .= false
cache.z[(length(cache.λ) + 1):end] .= vec(u)
SciMLBase.reinit!(cache.nlcache, cache.z; p = (u, cache.J, p))
end

function init_manifold_projection_nonlinear_problem(
IIP::Val{iip}, manifold, autodiff, manifold_jacobian, ũ, p, alg;
resid_prototype = nothing, kwargs...) where {iip}
if iip
if resid_prototype !== nothing
gu = similar(resid_prototype)
λ = similar(resid_prototype)
else
@warn "`resid_prototype` not provided for in-place problem. Assuming size of \
output is the same as input. This might be incorrect." maxlog=1
gu = similar(ũ)
λ = similar(ũ)
end
else
gu = nothing
λ = manifold(ũ, p)
end

J, di_extras = setup_manifold_jacobian(manifold_jacobian, autodiff, IIP, manifold,
gu, ũ, p)
z = vcat(vec(λ), vec(ũ))

nlfunc = if iip
let λlen = length(λ), λsz = size(λ), zsz = size(ũ)
@views (resid, u, ps) -> begin
ũ2, J2, p2 = ps
λ2, z2 = u[1:λlen], u[(λlen + 1):end]
manifold(reshape(resid[1:λlen], λsz), reshape(z2, zsz), p2)
resid[(λlen + 1):end] .= z2 .- vec(ũ2) .+ vec(vec(J2' * λ2))
end
end
else
let λlen = length(λ), zsz = size(ũ)
@views (u, ps) -> begin
ũ2, J2, p2 = ps
λ2, z2 = u[1:λlen], u[(λlen + 1):end]
gz = vec(manifold(reshape(z2, zsz), p2))
resid = z2 .- vec(ũ2) .+ vec(J2' * λ2)
return vcat(gz, resid)
end
end
end

nlprob = NonlinearProblem(NonlinearFunction{iip}(nlfunc), z, (ũ, J, p))
nlcache = SciMLBase.init(nlprob, alg; kwargs...)

return NonlinearSolveManifoldProjectionCache{iip}(
manifold, p, λ, z, ũ, gu, nlcache, true, J, manifold_jacobian, autodiff, di_extras)
end

@views function SciMLBase.solve!(cache::NonlinearSolveManifoldProjectionCache{iip}) where {iip}
sol = SciMLBase.solve!(cache.nlcache)
(; u, retcode) = sol
λ = reshape(u[1:length(cache.λ)], size(cache.λ))
= reshape(u[(length(cache.λ) + 1):end], size(cache.ũ))
return λ, ũ, retcode
end

# This is the algorithm described in Hairer III.
@concrete mutable struct SingleFactorizeManifoldProjectionCache{iip}
manifold
Expand Down Expand Up @@ -225,7 +309,7 @@ default_abstol(::Type{T}) where {T} = real(oneunit(T)) * (eps(real(one(T))))^(4

function init_manifold_projection(IIP::Val{iip}, manifold, autodiff, manifold_jacobian, ũ,
p; abstol = default_abstol(eltype(ũ)), maxiters = 1000,
resid_prototype = nothing) where {iip}
resid_prototype = nothing, kwargs...) where {iip}
if iip
if resid_prototype !== nothing
gu = similar(resid_prototype)
Expand Down Expand Up @@ -309,6 +393,11 @@ function setup_manifold_jacobian(
return J, di_extras
end

function setup_manifold_jacobian(
::Nothing, ::Nothing, ::Val{iip}, manifold, gu, ũ, p) where {iip}
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not provided.")
end

function compute_manifold_jacobian!(J, manifold_jacobian, autodiff, ::Val{iip},
manifold, gu, ũ, p, di_extras) where {iip}
if iip
Expand All @@ -329,10 +418,6 @@ function compute_manifold_jacobian!(J, ::Nothing, autodiff, ::Val{iip}, manifold
return J
end

function setup_manifold_jacobian(::Nothing, ::Nothing, args...)
error("`autodiff` is set to `nothing` and analytic manifold jacobian is not provided.")
end

function safe_factorize!(A::AbstractMatrix)
if issquare(A)
fact = LinearAlgebra.cholesky(A; check = false)
Expand Down
27 changes: 21 additions & 6 deletions test/manifold_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,27 @@ solve(prob, Vern7(), callback = cb_t)

# autodiff=false
cb_false = ManifoldProjection(
g; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
g; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2),
autodiff = AutoFiniteDiff())
solve(prob, Vern7(), callback = cb_false)
sol = solve(prob, Vern7(), callback = cb_false)
@test sol.u[end][1]^2 + sol.u[end][2]^2 2

cb_t_false = ManifoldProjection(g_t,
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2))
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), resid_prototype = zeros(2),
autodiff = AutoFiniteDiff())
solve(prob, Vern7(), callback = cb_t_false)
sol_t = solve(prob, Vern7(), callback = cb_t_false)
@test sol_t.u == sol.u && sol_t.t == sol.t

# test array partitions
function f_ap!(du, u, p, t)
du[1:2] .= u[3:4]
du[3:4] .= u[1:2]
end

u₀ = ArrayPartition(ones(2), ones(2))
prob = ODEProblem(f, u₀, (0.0, 100.0))
prob = ODEProblem(f_ap!, u₀, (0.0, 100.0))

sol = solve(prob, Vern7(), callback = cb)
@test sol.u[end][1]^2 + sol.u[end][2]^2 2
Expand All @@ -71,6 +78,12 @@ sol = solve(prob, Vern7(), callback = cb_unsat)
@test !SciMLBase.successful_retcode(sol)
@test last(sol.t) != 100.0

cb_unsat = ManifoldProjection(
g_unsat; resid_prototype = zeros(2), autodiff = AutoForwardDiff(), nlsolve = NewtonRaphson())
sol = solve(prob, Vern7(), callback = cb_unsat)
@test !SciMLBase.successful_retcode(sol)
@test last(sol.t) != 100.0

# Tests for OOP Manifold Projection
function g_oop(u, p)
return [u[2]^2 + u[1]^2 - 2
Expand Down Expand Up @@ -98,20 +111,22 @@ solve(prob, Vern7(), callback = cb_t)

# autodiff=false
cb_false = ManifoldProjection(
g_oop; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), isinplace = Val(false))
g_oop; nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), autodiff = AutoFiniteDiff())
solve(prob, Vern7(), callback = cb_false)
sol = solve(prob, Vern7(), callback = cb_false)
@test sol.u[end][1]^2 + sol.u[end][2]^2 2

cb_t_false = ManifoldProjection(g_oop_t,
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), isinplace = Val(false))
nlsolve = NewtonRaphson(; autodiff = AutoFiniteDiff()), autodiff = AutoFiniteDiff())
solve(prob, Vern7(), callback = cb_t_false)
sol_t = solve(prob, Vern7(), callback = cb_t_false)
@test sol_t.u == sol.u && sol_t.t == sol.t

# test array partitions
f_ap(u, p, t) = ArrayPartition(u[3:4], u[1:2])

u₀ = ArrayPartition(ones(2), ones(2))
prob = ODEProblem(f, u₀, (0.0, 100.0))
prob = ODEProblem(f_ap, u₀, (0.0, 100.0))

sol = solve(prob, Vern7(), callback = cb)
@test sol.u[end][1]^2 + sol.u[end][2]^2 2
Expand Down

0 comments on commit 7b52768

Please sign in to comment.