Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Out of place QuadratureAdjoint for Working with StaticArrays #680

Merged
merged 33 commits into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d8ea45b
changes for handling sarrays out of place
ba2tripleO Jul 4, 2022
966f3b7
OOP adjoint for QuadratureAdjoint sensealg
ba2tripleO Jul 5, 2022
2df07ee
clean up implementation
ba2tripleO Jul 6, 2022
bdde708
Update adjoint_common.jl
ba2tripleO Jul 6, 2022
2c067ab
oop dispatch for dgdu function, some tests, some corrections
ba2tripleO Jul 11, 2022
a5c4574
Merge branch 'master' into sarray
ba2tripleO Jul 11, 2022
4a7722f
Merge remote-tracking branch 'Abhishek-1Bhatt/sarray' into sarray
ba2tripleO Jul 11, 2022
971cea0
Update src/concrete_solve.jl
ba2tripleO Jul 11, 2022
10843ad
revert df_iip plus correction in dgdu
ba2tripleO Jul 11, 2022
1dba742
Merge remote-tracking branch 'Abhishek-1Bhatt/sarray' into sarray
ba2tripleO Jul 11, 2022
6158ba9
Update adjoint_common.jl
ba2tripleO Jul 11, 2022
a2d90fd
Update adjoint_common.jl
ba2tripleO Jul 11, 2022
7f54125
Merge branch 'SciML:master' into sarray
ba2tripleO Jul 13, 2022
b18f9d3
tests, returns
ba2tripleO Jul 14, 2022
174812a
Updated adjoint_oop.jl
ba2tripleO Jul 14, 2022
04ade0d
formatter
ba2tripleO Jul 14, 2022
9140817
fix oop adjoint tests
ba2tripleO Jul 19, 2022
a3747b4
OOP Adjoint on Numerical solve
ba2tripleO Jul 20, 2022
15fadaf
bug fix, formatter
ba2tripleO Jul 20, 2022
336becf
Merge branch 'SciML:master' into sarray
ba2tripleO Jul 20, 2022
4cafd99
Merge remote-tracking branch 'Abhishek-1Bhatt/sarray' into sarray
ba2tripleO Jul 20, 2022
40b7f2a
fix
ba2tripleO Jul 20, 2022
6517591
fix error
ba2tripleO Jul 20, 2022
e6cf65b
Added adjoint_oop.jl to runtests.jl
ba2tripleO Jul 20, 2022
8c59e9a
Merge branch 'SciML:master' into sarray
ba2tripleO Aug 12, 2022
d354ba7
StaticArrays AD, tests
ba2tripleO Aug 12, 2022
eb850e9
Update Project.toml
ba2tripleO Aug 12, 2022
5971d27
Update Project.toml
ba2tripleO Aug 12, 2022
d9f7ed0
formatter
ba2tripleO Aug 12, 2022
e7c8905
ODEAdjointProblem fix
ba2tripleO Aug 12, 2022
cf07c99
Merge branch 'master' into sarray
ChrisRackauckas Aug 13, 2022
02f1ae2
format
ChrisRackauckas Aug 13, 2022
7340ca4
no Enzyme v1.6
ChrisRackauckas Aug 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays"]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "SimpleChains", "StaticArrays"]
47 changes: 32 additions & 15 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S))

struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type,
dg2Type,
cacheType}
cacheType, solType}
isq::Bool
λ::λType
t::timeType
Expand All @@ -383,6 +383,7 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg
dgdu::dg1Type
dgdp::dg2Type
diffcache::cacheType
sol::solType
end

function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
Expand All @@ -392,13 +393,17 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time)
@unpack factorized_mass_matrix = sensefun.diffcache
prob = getprob(sensefun)
idx = length(prob.u0)

return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, dgdu, dgdp, sensefun.diffcache)
if ArrayInterfaceCore.ismutable(y)
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, dgdu, dgdp, sensefun.diffcache, nothing)
else
return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix,
sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol)
end
end

function (f::ReverseLossCallback)(integrator)
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f
@unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol = f
@unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache

p, u = integrator.p, integrator.u
Expand All @@ -407,16 +412,23 @@ function (f::ReverseLossCallback)(integrator)
copyto!(y, integrator.u[(end - idx + 1):end])
end

# Warning: alias here! Be careful with λ
gᵤ = isq ? λ : @view(λ[1:idx])
if dgdu !== nothing
dgdu(gᵤ, y, p, t[cur_time[]], cur_time[])
# add discrete dgdp contribution
if dgdp !== nothing && !isq
gp = @view(λ[(idx + 1):end])
dgdp(gp, y, p, t[cur_time[]], cur_time[])
u[(idx + 1):length(λ)] .+= gp
if ArrayInterfaceCore.ismutable(u)
# Warning: alias here! Be careful with λ
gᵤ = isq ? λ : @view(λ[1:idx])
if dgdu !== nothing
dgdu(gᵤ, y, p, t[cur_time[]], cur_time[])
# add discrete dgdp contribution
if dgdp !== nothing && !isq
gp = @view(λ[(idx + 1):end])
dgdp(gp, y, p, t[cur_time[]], cur_time[])
u[(idx + 1):length(λ)] .+= gp
end
end
else
@assert sensealg isa QuadratureAdjoint
outtype = DiffEqBase.parameterless_type(λ)
y = sol(t[cur_time[]])
gᵤ = dgdu(y, p, t[cur_time[]], cur_time[]; outtype = outtype)
end

if issemiexplicitdae
Expand All @@ -434,7 +446,12 @@ function (f::ReverseLossCallback)(integrator)
F !== I && F !== (I, I) && ldiv!(F, Δλd)
end

u[diffvar_idxs] .+= Δλd
if ArrayInterfaceCore.ismutable(u)
u[diffvar_idxs] .+= Δλd
else
@assert sensealg isa QuadratureAdjoint
integrator.u += Δλd
end
u_modified!(integrator, true)
cur_time[] -= 1
return nothing
Expand Down
78 changes: 72 additions & 6 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

function adjoint_sensitivity_backpass(Δ)
function df(_out, u, p, t, i)
function df_iip(_out, u, p, t, i)
outtype = typeof(_out) <: SubArray ?
DiffEqBase.parameterless_type(_out.parent) :
DiffEqBase.parameterless_type(_out)
Expand Down Expand Up @@ -307,16 +307,82 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
end
end

# Out-of-place counterpart of `df_iip`: rather than filling a preallocated
# buffer, select the part of the incoming cotangent `Δ` that belongs to saved
# time point `i` and return it.
#
# `u`, `p` and `t` are accepted for signature compatibility but are not read in
# this body. `outtype` is forwarded to `adapt` in the branches that call it.
# Returns `nothing` (bare `return`) when `Δ` carries no tangent information.
# `only_end`, `Δ` and `_save_idxs` are captured from the enclosing scope.
function df_oop(u, p, t, i; outtype = nothing)
    if only_end
        eltype(Δ) <: NoTangent && return
        if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1
            # user did sol[end] on only_end
            if typeof(_save_idxs) <: Number
                x = vec(Δ[1])
                _out = adapt(outtype, @view(x[_save_idxs]))
            elseif _save_idxs isa Colon
                _out = adapt(outtype, vec(Δ[1]))
            else
                _out = adapt(outtype,
                             vec(Δ[1])[_save_idxs])
            end
        else
            Δ isa NoTangent && return
            if _save_idxs isa Colon
                _out = adapt(outtype, vec(Δ))
            else
                # Scalar and collection indices were handled by two textually
                # identical branches; a single branch covers both.
                x = vec(Δ)
                _out = adapt(outtype, @view(x[_save_idxs]))
            end
        end
    else
        !Base.isconcretetype(eltype(Δ)) &&
            (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return
        if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution
            x = Δ[i]
            if typeof(_save_idxs) <: Number
                _out = @view(x[_save_idxs])
            elseif _save_idxs isa Colon
                _out = vec(x)
            else
                _out = vec(@view(x[_save_idxs]))
            end
        else
            # View `Δ` as a matrix whose last dimension indexes the time points.
            Δmat = reshape(Δ, prod(size(Δ)[1:(end - 1)]), size(Δ)[end])
            if typeof(_save_idxs) <: Number
                _out = adapt(outtype, Δmat[_save_idxs, i])
            else
                # NOTE(review): for a non-scalar, non-Colon `_save_idxs` this takes
                # every row of column `i`, exactly like the Colon case. Presumably
                # `Δ` already holds only the saved rows here -- confirm.
                _out = vec(adapt(outtype, Δmat[:, i]))
            end
        end
    end
    return _out
end

if haskey(kwargs_adj, :callback_adj)
cb2 = CallbackSet(cb, kwargs[:callback_adj])
else
cb2 = cb
end

du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)
if ArrayInterfaceCore.ismutable(eltype(sol.u))
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
dgdu_discrete = df_iip,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)
else
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
dgdu_discrete = df_oop,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)
end

du0 = reshape(du0, size(u0))
dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
Expand Down
56 changes: 56 additions & 0 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ function vecjacobian!(dλ, y, λ, p, t, S::TS;
return
end

# Non-mutating entry point for the vector-Jacobian product.
# Picks the backend from `S.sensealg.autojacvec` and forwards every argument
# positionally to the matching `_vecjacobian` method, returning its result.
function vecjacobian(y, λ, p, t, S::TS; dgrad = nothing, dy = nothing,
                     W = nothing) where {TS <: SensitivityFunction}
    vjp_backend = S.sensealg.autojacvec
    return _vecjacobian(y, λ, p, t, S, vjp_backend, dgrad, dy, W)
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy,
W) where {TS <: SensitivityFunction}
@unpack sensealg, f = S
Expand Down Expand Up @@ -573,6 +579,43 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad,
return
end

# Out-of-place vector-Jacobian product computed with Zygote.
#
# Pulls `λ` back through `f(u, p, t)` (or `f(u, p, t, W)` when `W` is given)
# with respect to both `u` and `p`.
#
# Returns `(dy, dλ, dgrad)`:
#   * `dy`    -- the `dy` argument; when not `nothing` it is overwritten in place
#                with the primal value `f(...)`.
#   * `dλ`    -- pullback with respect to the state, or `zero(λ)` when Zygote
#                reports `nothing` and `allow_nothing` is set.
#   * `dgrad` -- the `dgrad` argument; when not `nothing` it is overwritten in
#                place with the pullback with respect to `p`.
#
# Throws `ZygoteVJPNothingError` when a required pullback is `nothing` and
# `sensealg.autojacvec.allow_nothing` is false.
function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy,
                      W) where {TS <: SensitivityFunction}
    @unpack sensealg, f = S

    if W === nothing
        _dy, back = Zygote.pullback(y, p) do u, p
            vec(f(u, p, t))
        end
    else
        _dy, back = Zygote.pullback(y, p) do u, p
            vec(f(u, p, t, W))
        end
    end

    # Grab values from `_dy` before `back` in case mutated
    dy !== nothing && (dy[:] .= vec(_dy))

    tmp1, tmp2 = back(λ)
    allow_nothing = sensealg.autojacvec.allow_nothing

    if tmp1 === nothing
        allow_nothing || throw(ZygoteVJPNothingError())
        # Previously `dλ` stayed unassigned on this path, so the final `return`
        # raised `UndefVarError`. A `nothing` pullback is treated as a zero one.
        dλ = zero(λ)
    else
        dλ = vec(tmp1)
    end

    if dgrad !== nothing
        if tmp2 === nothing
            allow_nothing || throw(ZygoteVJPNothingError())
        else
            dgrad[:] .= vec(tmp2)
        end
    end
    return dy, dλ, dgrad
end

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy,
W) where {TS <: SensitivityFunction}
@unpack sensealg = S
Expand Down Expand Up @@ -867,6 +910,19 @@ function accumulate_cost!(dλ, y, p, t, S::TS,
return nothing
end

# Non-mutating cost accumulation: subtract the cost gradients held in
# `S.diffcache` from the supplied adjoint derivatives and return the new values.
# `dgdu(y, p, t)` is always subtracted from `dλ`; `dgdp(y, p, t)` is subtracted
# from `dgrad` only when both `dgdp` and `dgrad` are not `nothing`.
function accumulate_cost(dλ, y, p, t, S::TS,
                         dgrad = nothing) where {TS <: SensitivityFunction}
    @unpack dgdu, dgdp = S.diffcache

    dλ = dλ - dgdu(y, p, t)
    if dgdp !== nothing && dgrad !== nothing
        dgrad = dgrad - dgdp(y, p, t)
    end
    return dλ, dgrad
end

function build_jac_config(alg, uf, u)
if alg_autodiff(alg)
jac_config = ForwardDiff.JacobianConfig(uf, u, u,
Expand Down
59 changes: 50 additions & 9 deletions src/quadrature_adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(du, u, p, t)
return nothing
end

# Out-of-place right-hand side of the quadrature adjoint system.
#
# Splits the state `u` via `split_states`, computes the vector-Jacobian product
# via `vecjacobian`, negates it, and -- unless `S.discrete` is set -- folds in the
# cost gradient via `accumulate_cost`. Returns the resulting `dλ`.
function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t)
    @unpack discrete = S

    # `split_states` returns a 5-tuple; its second slot is not used here.
    λ, _, y, dgrad, dy = split_states(u, t, S)

    _, dλ, dgrad = vecjacobian(y, λ, p, t, S; dgrad = dgrad, dy = dy)
    dλ *= (-one(eltype(λ)))

    if !discrete
        # The updated `dgrad` is not needed: only `dλ` is returned.
        dλ, _ = accumulate_cost(dλ, y, p, t, S, dgrad)
    end
    return dλ
end

function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true)
@unpack y, sol = S

Expand All @@ -49,6 +64,18 @@ function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; upda
λ, nothing, y, dλ, nothing, nothing
end

# Out-of-place variant of `split_states` for the quadrature adjoint.
#
# When `update` is true, the forward state `y` is re-evaluated from the stored
# solution as `sol(t, continuity = :right)`; otherwise the `y` stored in `S` is
# used. The adjoint state `λ` is `u` itself (no copy is made).
#
# Returns the 5-tuple `(λ, nothing, y, nothing, nothing)`, which the out-of-place
# sensitivity function destructures as `λ, grad, y, dgrad, dy`.
function split_states(u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true)
    @unpack y, sol = S

    if update
        y = sol(t, continuity = :right)
    end

    λ = u

    return λ, nothing, y, nothing, nothing
end

# g is either g(t,u,p) or discrete g(t,u,i)
@noinline function ODEAdjointProblem(sol, sensealg::QuadratureAdjoint,
t = nothing,
Expand Down Expand Up @@ -79,9 +106,13 @@ end
(dgdu_continuous === nothing && dgdp_continuous === nothing ||
g !== nothing))

len = length(u0)
λ = similar(u0, len)
λ .= false
if ArrayInterfaceCore.ismutable(u0)
len = length(u0)
λ = similar(u0, len)
λ .= false
else
λ = zero(u0)
end
sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol,
dgdu_continuous, dgdp_continuous)

Expand All @@ -102,7 +133,7 @@ end
odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix',
jac_prototype = adjoint_jac_prototype)
end
return ODEProblem(odefun, z0, tspan, p, callback = cb)
return ODEProblem{ArrayInterfaceCore.ismutable(z0)}(odefun, z0, tspan, p, callback = cb)
end

struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP,
Expand All @@ -129,7 +160,6 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
λ = zero(adj_sol.prob.u0)
# we need to alias `y`
f_cache = zero(y)
f_cache .= false
isautojacvec = get_jacvec(sensealg)

dgdp_cache = dgdp === nothing ? nothing : zero(p)
Expand Down Expand Up @@ -197,8 +227,14 @@ end
function (S::AdjointSensitivityIntegrand)(out, t)
@unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S
f = sol.prob.f
sol(y, t)
adj_sol(λ, t)
# if eltype(sol.u) <: StaticArrays.SArray
if ArrayInterfaceCore.ismutable(eltype(sol.u))
sol(y, t)
adj_sol(λ, t)
else
y = sol(t)
λ = adj_sol(t)
end
isautojacvec = get_jacvec(sensealg)
# y is aliased

Expand Down Expand Up @@ -309,8 +345,13 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi
end

for i in (length(t) - 1):-1:1
res .+= quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
if ArrayInterfaceCore.ismutable(res)
res .+= quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
else
res += quadgk(integrand, t[i], t[i + 1],
atol = abstol, rtol = reltol)[1]
end
if t[i] == t[i + 1]
integrand = update_integrand_and_dgrad(res, sensealg, callback,
integrand,
Expand Down
Loading