From d8ea45bd2368d8605575a4da3105dc0dd3c667d5 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Mon, 4 Jul 2022 14:15:10 +0530 Subject: [PATCH 01/25] changes for handling sarrays out of place --- Project.toml | 1 + src/SciMLSensitivity.jl | 1 + src/adjoint_common.jl | 8 +++-- src/concrete_solve.jl | 30 +++++++++++++---- src/derivative_wrappers.jl | 67 ++++++++++++++++++++++++++++++++++++++ src/quadrature_adjoint.jl | 34 ++++++++++++++++--- 6 files changed, 128 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 48e386174..25da29452 100644 --- a/Project.toml +++ b/Project.toml @@ -32,6 +32,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 6c35e9a6a..d83e60481 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -14,6 +14,7 @@ import ZygoteRules, Zygote, ReverseDiff import ArrayInterfaceCore, ArrayInterfaceTracker import Enzyme import GPUArraysCore +import StaticArrays using Cassette, DiffRules using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 57225ba53..4bdc6e671 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -419,8 +419,12 @@ function (f::ReverseLossCallback)(integrator) if F !== nothing F !== I && F !== (I, I) && ldiv!(F, Δλd) end - - u[diffvar_idxs] .+= Δλd + + if u isa StaticArrays.SArray + u += Δλd + else + u[diffvar_idxs] .+= Δλd + end u_modified!(integrator, true) cur_time[] -= 1 return nothing diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 5574d807f..9d7b506fc 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -258,7 +258,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ[1]) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype, vec(Δ[1])) + if _out isa StaticArrays.SArray + _out = adapt(outtype, vec(Δ[1])) + else + vec(_out) .= adapt(outtype, vec(Δ[1])) + end else vec(@view(_out[_save_idxs])) .= adapt(outtype, vec(Δ[1])[_save_idxs]) @@ -269,7 +273,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype, vec(Δ)) + if _out isa StaticArrays.SArray + _out = adapt(outtype, vec(Δ)) + else + vec(_out) .= adapt(outtype, vec(Δ)) + end else x = vec(Δ) vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) @@ -283,7 +291,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, if typeof(_save_idxs) <: Number _out[_save_idxs] = @view(x[_save_idxs]) elseif _save_idxs isa Colon - vec(_out) .= vec(x) + if _out isa StaticArrays.SArray + _out = vec(x) + else + vec(_out) .= vec(x) + end else vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) end @@ -293,9 +305,15 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, reshape(Δ, prod(size(Δ)[1:(end - 1)]), size(Δ)[end])[_save_idxs, i]) elseif _save_idxs isa Colon - vec(_out) .= vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) + if _out isa StaticArrays.SArray + _out = vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + vec(_out) .= vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + end else vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, reshape(Δ, diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 4f069a8e9..5b1c4be02 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -231,6 +231,12 @@ function vecjacobian!(dλ, y, λ, p, t, S::TS; return end +function vecjacobian(y, λ, p, t, S::TS; + dgrad = nothing, dy = nothing, + W = nothing) where {TS <: SensitivityFunction} + return _vecjacobian(y, λ, p, t, S, S.sensealg.autojacvec, dgrad, dy, W) +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::Bool, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg, f = S @@ -582,6 +588,44 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, return end +function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, + W) where {TS <: SensitivityFunction} + @unpack sensealg, f = S + prob = getprob(S) + + isautojacvec = get_jacvec(sensealg) + + if W === nothing + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t)) + end + else + _dy, back = Zygote.pullback(y, p) do u, p + vec(f(u, p, t, W)) + end + end + + # Grab values from `_dy` before `back` in case mutated + dy !== nothing && (dy[:] .= vec(_dy)) + + tmp1, tmp2 = back(λ) + if tmp1 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp1 !== nothing + (dλ = vec(tmp1)) + end + + if dgrad !== nothing + if tmp2 === nothing && !sensealg.autojacvec.allow_nothing + throw(ZygoteVJPNothingError()) + elseif tmp2 !== nothing + (dgrad[:] .= vec(tmp2)) + end + end + @show typeof(dλ), dλ + return dλ +end + function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, W) where {TS <: SensitivityFunction} @unpack sensealg = S @@ -868,6 +912,29 @@ function accumulate_cost!(dλ, y, p, t, S::TS, return nothing end +function accumulate_cost(dλ, y, p, t, S::TS, + dgrad = nothing) where {TS <: SensitivityFunction} + @unpack dg, dg_val, g, g_grad_config = S.diffcache + if dg !== nothing + if !(dg isa Tuple) + dg(dg_val, y, p, t) + dλ -= vec(dg_val) + else + dg[1](dg_val[1], y, p, t) + dλ -= vec(dg_val[1]) + if dgrad !== nothing + dg[2](dg_val[2], y, p, t) + dgrad .-= vec(dg_val[2]) + end + end + else + g.t = t + gradient!(dg_val, g, y, S.sensealg, g_grad_config) + dλ -= vec(dg_val) + end + return dλ +end + function build_jac_config(alg, uf, u) if alg_autodiff(alg) jac_config = ForwardDiff.JacobianConfig(uf, u, u, diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index ec2f6145f..a362f3c0e 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -31,6 +31,21 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(du, u, p, t) return nothing end +function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t) + @unpack sol, discrete = S + f = sol.prob.f + + λ, grad, y, dgrad, dy = split_states(u, t, S) + + dλ = vecjacobian(y, λ, p, t, S) * (-one(eltype(λ))) + + if !discrete + return accumulate_cost(dλ, y, p, t, S) + end + @show typeof(λ), λ, typeof(dλ), dλ + return dλ +end + function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true) @unpack y, sol = S @@ -48,6 +63,18 @@ function split_states(du, u, t, S::ODEQuadratureAdjointSensitivityFunction; upda λ, nothing, y, dλ, nothing, nothing end +function split_states(u, t, S::ODEQuadratureAdjointSensitivityFunction; update = true) + @unpack y, sol = S + + if update + y = sol(t, continuity = :right) + end + + λ = u + + λ, nothing, y, nothing, nothing +end + # g is either g(t,u,p) or discrete g(t,u,i) @noinline function ODEAdjointProblem(sol, sensealg::QuadratureAdjoint, t = nothing, @@ -70,9 +97,7 @@ end discrete = (t !== nothing && dg_continuous === nothing) - len = length(u0) - λ = similar(u0, len) - λ .= false + λ = zero(u0) sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol, dg_continuous) @@ -92,7 +117,7 @@ end odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix', jac_prototype = adjoint_jac_prototype) end - return ODEProblem(odefun, z0, tspan, p, callback = cb) + return ODEProblem{!(z0 isa StaticArrays.SArray)}(odefun, z0, tspan, p, callback = cb) end struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP, @@ -119,7 +144,6 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing) λ = zero(adj_sol.prob.u0) # we need to alias `y` f_cache = zero(y) - f_cache .= false isautojacvec = get_jacvec(sensealg) dgdp_cache = dgdp === nothing ? nothing : zero(p) From 966f3b767a89754ac554dd86649631c2f0cb0165 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 6 Jul 2022 00:15:14 +0530 Subject: [PATCH 02/25] OOP adjoint for QuadratureAdjoint sensealg --- src/adjoint_common.jl | 13 +- src/concrete_solve.jl | 235 +++++++++++++++++++++++++++++++++---- src/derivative_wrappers.jl | 1 - src/quadrature_adjoint.jl | 10 +- 4 files changed, 227 insertions(+), 32 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 4bdc6e671..2cdea5ae2 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -401,9 +401,14 @@ function (f::ReverseLossCallback)(integrator) copyto!(y, integrator.u[(end - idx + 1):end]) end - # Warning: alias here! Be careful with λ - gᵤ = isq ? λ : @view(λ[1:idx]) - g(gᵤ, y, p, t[cur_time[]], cur_time[]) + if u isa StaticArrays.SArray + gᵤ = isq ? λ : @view(λ[1:idx]) + gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[]) + else + # Warning: alias here! Be careful with λ + gᵤ = isq ? λ : @view(λ[1:idx]) + g(gᵤ, y, p, t[cur_time[]], cur_time[]) + end if issemiexplicitdae jacobian!(J, uf, y, f_cache, sensealg, jac_config) @@ -421,7 +426,7 @@ function (f::ReverseLossCallback)(integrator) end if u isa StaticArrays.SArray - u += Δλd + integrator.u += Δλd else u[diffvar_idxs] .+= Δλd end diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 9d7b506fc..e5ce1faab 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -258,11 +258,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ[1]) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - if _out isa StaticArrays.SArray - _out = adapt(outtype, vec(Δ[1])) - else - vec(_out) .= adapt(outtype, vec(Δ[1])) - end + vec(_out) .= adapt(outtype, vec(Δ[1])) else vec(@view(_out[_save_idxs])) .= adapt(outtype, vec(Δ[1])[_save_idxs]) @@ -273,11 +269,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - if _out isa StaticArrays.SArray - _out = adapt(outtype, vec(Δ)) - else - vec(_out) .= adapt(outtype, vec(Δ)) - end + vec(_out) .= adapt(outtype, vec(Δ)) else x = vec(Δ) vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) @@ -291,11 +283,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, if typeof(_save_idxs) <: Number _out[_save_idxs] = @view(x[_save_idxs]) elseif _save_idxs isa Colon - if _out isa StaticArrays.SArray - _out = vec(x) - else - vec(_out) .= vec(x) - end + vec(_out) .= vec(x) else vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) end @@ -305,15 +293,213 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, reshape(Δ, prod(size(Δ)[1:(end - 1)]), size(Δ)[end])[_save_idxs, i]) elseif _save_idxs isa Colon - if _out isa StaticArrays.SArray - _out = vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) - else - vec(_out) .= vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) - end + vec(_out) .= vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, + i])) + end + end + end + end + + if haskey(kwargs_adj, :callback_adj) + cb2 = CallbackSet(cb, kwargs[:callback_adj]) + else + cb2 = cb + end + + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dg_discrete = df, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + + du0 = reshape(du0, size(u0)) + dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : + reshape(dp', size(p)) + + if originator isa SciMLBase.TrackerOriginator || + originator isa SciMLBase.ReverseDiffOriginator + (NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + else + (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), + ntuple(_ -> NoTangent(), length(args))...) + end + end + out, adjoint_sensitivity_backpass +end + +function DiffEqBase._concrete_solve_adjoint(prob, alg, + sensealg::AbstractAdjointSensitivityAlgorithm, + u0::StaticArrays.SVector, p, originator::SciMLBase.ADOriginator, + args...; save_start = true, save_end = true, + saveat = eltype(prob.tspan)[], + save_idxs = nothing, + kwargs...) + if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || + (p isa AbstractArray && !Base.isconcretetype(eltype(p))) + throw(AdjointSensitivityParameterCompatibilityError()) + end + + # Remove saveat, etc. from kwargs since it's handled separately + # and letting it jump back in there can break the adjoint + kwargs_prob = NamedTuple(filter(x -> x[1] != :saveat && x[1] != :save_start && + x[1] != :save_end && x[1] != :save_idxs, + prob.kwargs)) + + if haskey(kwargs, :callback) + cb = track_callbacks(CallbackSet(kwargs[:callback]), prob.tspan[1], prob.u0, prob.p, + sensealg) + _prob = remake(prob; u0 = u0, p = p, kwargs = merge(kwargs_prob, (; callback = cb))) + else + cb = nothing + _prob = remake(prob; u0 = u0, p = p, kwargs = kwargs_prob) + end + + # Remove callbacks, saveat, etc. from kwargs since it's handled separately + kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs)) + + # Capture the callback_adj for the reverse pass and remove both callbacks + kwargs_adj = NamedTuple{ + Base.diff_names(Base._nt_names(values(kwargs)), + (:callback_adj, :callback))}(values(kwargs)) + isq = sensealg isa QuadratureAdjoint + if typeof(sensealg) <: BacksolveAdjoint + sol = solve(_prob, alg, args...; save_noise = true, + save_start = save_start, save_end = save_end, + saveat = saveat, kwargs_fwd...) + elseif ischeckpointing(sensealg) + sol = solve(_prob, alg, args...; save_noise = true, + save_start = true, save_end = true, + saveat = saveat, kwargs_fwd...) + else + sol = solve(_prob, alg, args...; save_noise = true, save_start = true, + save_end = true, kwargs_fwd...) + end + + # Force `save_start` and `save_end` in the forward pass This forces the + # solver to do the backsolve all the way back to `u0` Since the start aliases + # `_prob.u0`, this doesn't actually use more memory But it cleans up the + # implementation and makes `save_start` and `save_end` arg safe. + if typeof(sensealg) <: BacksolveAdjoint + # Saving behavior unchanged + ts = sol.t + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + out = DiffEqBase.sensitivity_solution(sol, sol.u, ts) + elseif saveat isa Number + if _prob.tspan[2] > _prob.tspan[1] + ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[2] + else + ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[1] + end + # if _prob.tspan[2]-_prob.tspan[1] is not a multiple of saveat, one looses the last ts value + sol.t[end] !== ts[end] && (ts = fix_endpoints(sensealg, sol, ts)) + if cb === nothing + _out = sol(ts) + else + _, duplicate_iterator_times = separate_nonunique(sol.t) + _out, ts = out_and_ts(ts, duplicate_iterator_times, sol) + end + + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + elseif isempty(saveat) + no_start = !save_start + no_end = !save_end + sol_idxs = 1:length(sol) + no_start && (sol_idxs = sol_idxs[2:end]) + no_end && (sol_idxs = sol_idxs[1:(end - 1)]) + only_end = length(sol_idxs) <= 1 + _u = sol.u[sol_idxs] + u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] + ts = sol.t[sol_idxs] + out = DiffEqBase.sensitivity_solution(sol, u, ts) + else + _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching + if cb === nothing + _saveat = eltype(_saveat) <: typeof(prob.tspan[2]) ? + convert.(typeof(_prob.tspan[2]), _saveat) : _saveat + ts = _saveat + _out = sol(ts) + else + _ts, duplicate_iterator_times = separate_nonunique(sol.t) + _out, ts = out_and_ts(_saveat, duplicate_iterator_times, sol) + end + + out = if save_idxs === nothing + out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) + else + out = DiffEqBase.sensitivity_solution(sol, + [_out[i][save_idxs] + for i in 1:length(_out)], ts) + end + only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] + end + + _save_idxs = save_idxs === nothing ? Colon() : save_idxs + + function adjoint_sensitivity_backpass(Δ) + function df(_out, u, p, t, i) + outtype = typeof(_out) <: SubArray ? + DiffEqBase.parameterless_type(_out.parent) : + DiffEqBase.parameterless_type(_out) + if only_end + eltype(Δ) <: NoTangent && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 + # user did sol[end] on only_end + if typeof(_save_idxs) <: Number + x = vec(Δ[1]) + _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ[1])) + else + vec(@view(_out[_save_idxs])) .= adapt(outtype, + vec(Δ[1])[_save_idxs]) + end + else + Δ isa NoTangent && return + if typeof(_save_idxs) <: Number + x = vec(Δ) + _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ)) + else + x = vec(Δ) + vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) + end + end + else + !Base.isconcretetype(eltype(Δ)) && + (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution + x = Δ[i] + if typeof(_save_idxs) <: Number + _out[_save_idxs] = @view(x[_save_idxs]) + elseif _save_idxs isa Colon + _out = vec(x) + else + vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) + end + else + if typeof(_save_idxs) <: Number + _out[_save_idxs] = adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[_save_idxs, i]) + elseif _save_idxs isa Colon + _out = vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i]))######Required Assignment################# else vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, reshape(Δ, @@ -323,6 +509,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, end end end + return _out end if haskey(kwargs_adj, :callback_adj) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 5b1c4be02..1dac7a791 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -622,7 +622,6 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, (dgrad[:] .= vec(tmp2)) end end - @show typeof(dλ), dλ return dλ end diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index a362f3c0e..a7f87cc75 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -42,7 +42,6 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t) if !discrete return accumulate_cost(dλ, y, p, t, S) end - @show typeof(λ), λ, typeof(dλ), dλ return dλ end @@ -211,8 +210,13 @@ end function (S::AdjointSensitivityIntegrand)(out, t) @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S f = sol.prob.f - sol(y, t) - adj_sol(λ, t) + if eltype(sol.u) <: StaticArrays.SArray + y = sol(t) + λ = adj_sol(t) + else + sol(y, t) + adj_sol(λ, t) + end isautojacvec = get_jacvec(sensealg) # y is aliased From 2df07eecd622b4467317aa2fd000e72b5c37dfe3 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 6 Jul 2022 12:55:37 +0530 Subject: [PATCH 03/25] clean up implementation --- Project.toml | 1 - src/SciMLSensitivity.jl | 1 - src/adjoint_common.jl | 18 +-- src/concrete_solve.jl | 236 +++++--------------------------------- src/quadrature_adjoint.jl | 11 +- 5 files changed, 43 insertions(+), 224 deletions(-) diff --git a/Project.toml b/Project.toml index 25da29452..48e386174 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index d83e60481..6c35e9a6a 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -14,7 +14,6 @@ import ZygoteRules, Zygote, ReverseDiff import ArrayInterfaceCore, ArrayInterfaceTracker import Enzyme import GPUArraysCore -import StaticArrays using Cassette, DiffRules using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 2cdea5ae2..d622a5a6f 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -401,13 +401,15 @@ function (f::ReverseLossCallback)(integrator) copyto!(y, integrator.u[(end - idx + 1):end]) end - if u isa StaticArrays.SArray - gᵤ = isq ? λ : @view(λ[1:idx]) - gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[]) - else + # if u isa StaticArrays.SArray + if ArrayInterfaceCore.ismutable(u) # Warning: alias here! Be careful with λ gᵤ = isq ? λ : @view(λ[1:idx]) g(gᵤ, y, p, t[cur_time[]], cur_time[]) + else + @assert sensealg isa QuadratureAdjoint + gᵤ = isq ? λ : @view(λ[1:idx]) + gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[]) end if issemiexplicitdae @@ -425,10 +427,12 @@ function (f::ReverseLossCallback)(integrator) F !== I && F !== (I, I) && ldiv!(F, Δλd) end - if u isa StaticArrays.SArray - integrator.u += Δλd - else + # if u isa StaticArrays.SArray + if ArrayInterfaceCore.ismutable(u) u[diffvar_idxs] .+= Δλd + else + @assert sensealg isa QuadratureAdjoint + integrator.u += Δλd end u_modified!(integrator, true) cur_time[] -= 1 diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index e5ce1faab..d2d62fb55 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -258,7 +258,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ[1]) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype, vec(Δ[1])) + if ArrayInterfaceCore.ismutable(u) + vec(_out) .= adapt(outtype, vec(Δ[1])) + else + _out = adapt(outtype, vec(Δ[1])) + end else vec(@view(_out[_save_idxs])) .= adapt(outtype, vec(Δ[1])[_save_idxs]) @@ -269,7 +273,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - vec(_out) .= adapt(outtype, vec(Δ)) + if ArrayInterfaceCore.ismutable(u) + vec(_out) .= adapt(outtype, vec(Δ)) + else + _out = adapt(outtype, vec(Δ)) + end else x = vec(Δ) vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) @@ -283,7 +291,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, if typeof(_save_idxs) <: Number _out[_save_idxs] = @view(x[_save_idxs]) elseif _save_idxs isa Colon - vec(_out) .= vec(x) + if ArrayInterfaceCore.ismutable(u) + vec(_out) .= vec(x) + else + _out = vec(x) + end else vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) end @@ -293,9 +305,15 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, reshape(Δ, prod(size(Δ)[1:(end - 1)]), size(Δ)[end])[_save_idxs, i]) elseif _save_idxs isa Colon - vec(_out) .= vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) + if ArrayInterfaceCore.ismutable(u) + vec(_out) .= vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + _out = vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + end else vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, reshape(Δ, @@ -305,211 +323,9 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, end end end - end - - if haskey(kwargs_adj, :callback_adj) - cb2 = CallbackSet(cb, kwargs[:callback_adj]) - else - cb2 = cb - end - - du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dg_discrete = df, - sensealg = sensealg, - callback = cb2, - kwargs_adj...) - - du0 = reshape(du0, size(u0)) - dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : - reshape(dp', size(p)) - - if originator isa SciMLBase.TrackerOriginator || - originator isa SciMLBase.ReverseDiffOriginator - (NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) - else - (NoTangent(), NoTangent(), NoTangent(), du0, dp, NoTangent(), - ntuple(_ -> NoTangent(), length(args))...) - end - end - out, adjoint_sensitivity_backpass -end - -function DiffEqBase._concrete_solve_adjoint(prob, alg, - sensealg::AbstractAdjointSensitivityAlgorithm, - u0::StaticArrays.SVector, p, originator::SciMLBase.ADOriginator, - args...; save_start = true, save_end = true, - saveat = eltype(prob.tspan)[], - save_idxs = nothing, - kwargs...) - if !(typeof(p) <: Union{Nothing, SciMLBase.NullParameters, AbstractArray}) || - (p isa AbstractArray && !Base.isconcretetype(eltype(p))) - throw(AdjointSensitivityParameterCompatibilityError()) - end - - # Remove saveat, etc. from kwargs since it's handled separately - # and letting it jump back in there can break the adjoint - kwargs_prob = NamedTuple(filter(x -> x[1] != :saveat && x[1] != :save_start && - x[1] != :save_end && x[1] != :save_idxs, - prob.kwargs)) - - if haskey(kwargs, :callback) - cb = track_callbacks(CallbackSet(kwargs[:callback]), prob.tspan[1], prob.u0, prob.p, - sensealg) - _prob = remake(prob; u0 = u0, p = p, kwargs = merge(kwargs_prob, (; callback = cb))) - else - cb = nothing - _prob = remake(prob; u0 = u0, p = p, kwargs = kwargs_prob) - end - - # Remove callbacks, saveat, etc. from kwargs since it's handled separately - kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs)) - - # Capture the callback_adj for the reverse pass and remove both callbacks - kwargs_adj = NamedTuple{ - Base.diff_names(Base._nt_names(values(kwargs)), - (:callback_adj, :callback))}(values(kwargs)) - isq = sensealg isa QuadratureAdjoint - if typeof(sensealg) <: BacksolveAdjoint - sol = solve(_prob, alg, args...; save_noise = true, - save_start = save_start, save_end = save_end, - saveat = saveat, kwargs_fwd...) - elseif ischeckpointing(sensealg) - sol = solve(_prob, alg, args...; save_noise = true, - save_start = true, save_end = true, - saveat = saveat, kwargs_fwd...) - else - sol = solve(_prob, alg, args...; save_noise = true, save_start = true, - save_end = true, kwargs_fwd...) - end - - # Force `save_start` and `save_end` in the forward pass This forces the - # solver to do the backsolve all the way back to `u0` Since the start aliases - # `_prob.u0`, this doesn't actually use more memory But it cleans up the - # implementation and makes `save_start` and `save_end` arg safe. - if typeof(sensealg) <: BacksolveAdjoint - # Saving behavior unchanged - ts = sol.t - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - out = DiffEqBase.sensitivity_solution(sol, sol.u, ts) - elseif saveat isa Number - if _prob.tspan[2] > _prob.tspan[1] - ts = _prob.tspan[1]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[2] - else - ts = _prob.tspan[2]:convert(typeof(_prob.tspan[2]), abs(saveat)):_prob.tspan[1] - end - # if _prob.tspan[2]-_prob.tspan[1] is not a multiple of saveat, one looses the last ts value - sol.t[end] !== ts[end] && (ts = fix_endpoints(sensealg, sol, ts)) - if cb === nothing - _out = sol(ts) - else - _, duplicate_iterator_times = separate_nonunique(sol.t) - _out, ts = out_and_ts(ts, duplicate_iterator_times, sol) - end - - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) - else - out = DiffEqBase.sensitivity_solution(sol, - [_out[i][save_idxs] - for i in 1:length(_out)], ts) - end - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - elseif isempty(saveat) - no_start = !save_start - no_end = !save_end - sol_idxs = 1:length(sol) - no_start && (sol_idxs = sol_idxs[2:end]) - no_end && (sol_idxs = sol_idxs[1:(end - 1)]) - only_end = length(sol_idxs) <= 1 - _u = sol.u[sol_idxs] - u = save_idxs === nothing ? _u : [x[save_idxs] for x in _u] - ts = sol.t[sol_idxs] - out = DiffEqBase.sensitivity_solution(sol, u, ts) - else - _saveat = saveat isa Array ? sort(saveat) : saveat # for minibatching - if cb === nothing - _saveat = eltype(_saveat) <: typeof(prob.tspan[2]) ? - convert.(typeof(_prob.tspan[2]), _saveat) : _saveat - ts = _saveat - _out = sol(ts) - else - _ts, duplicate_iterator_times = separate_nonunique(sol.t) - _out, ts = out_and_ts(_saveat, duplicate_iterator_times, sol) - end - - out = if save_idxs === nothing - out = DiffEqBase.sensitivity_solution(sol, _out.u, ts) - else - out = DiffEqBase.sensitivity_solution(sol, - [_out[i][save_idxs] - for i in 1:length(_out)], ts) - end - only_end = length(ts) == 1 && ts[1] == _prob.tspan[2] - end - - _save_idxs = save_idxs === nothing ? Colon() : save_idxs - - function adjoint_sensitivity_backpass(Δ) - function df(_out, u, p, t, i) - outtype = typeof(_out) <: SubArray ? - DiffEqBase.parameterless_type(_out.parent) : - DiffEqBase.parameterless_type(_out) - if only_end - eltype(Δ) <: NoTangent && return - if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 - # user did sol[end] on only_end - if typeof(_save_idxs) <: Number - x = vec(Δ[1]) - _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) - elseif _save_idxs isa Colon - _out = adapt(outtype, vec(Δ[1])) - else - vec(@view(_out[_save_idxs])) .= adapt(outtype, - vec(Δ[1])[_save_idxs]) - end - else - Δ isa NoTangent && return - if typeof(_save_idxs) <: Number - x = vec(Δ) - _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) - elseif _save_idxs isa Colon - _out = adapt(outtype, vec(Δ)) - else - x = vec(Δ) - vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) - end - end - else - !Base.isconcretetype(eltype(Δ)) && - (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return - if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution - x = Δ[i] - if typeof(_save_idxs) <: Number - _out[_save_idxs] = @view(x[_save_idxs]) - elseif _save_idxs isa Colon - _out = vec(x) - else - vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) - end - else - if typeof(_save_idxs) <: Number - _out[_save_idxs] = adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[_save_idxs, i]) - elseif _save_idxs isa Colon - _out = vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i]))######Required Assignment################# - else - vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, - reshape(Δ, - prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, - i])) - end - end + if !(ArrayInterfaceCore.ismutable(u0)) + return _out end - return _out end if haskey(kwargs_adj, :callback_adj) diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index a7f87cc75..057a5fcdc 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -116,7 +116,7 @@ end odefun = ODEFunction(sense, mass_matrix = sol.prob.f.mass_matrix', jac_prototype = adjoint_jac_prototype) end - return ODEProblem{!(z0 isa StaticArrays.SArray)}(odefun, z0, tspan, p, callback = cb) + return ODEProblem{ArrayInterfaceCore.ismutable(z0)}(odefun, z0, tspan, p, callback = cb) end struct AdjointSensitivityIntegrand{pType, uType, lType, rateType, S, AS, PF, PJC, PJT, DGP, @@ -210,12 +210,13 @@ end function (S::AdjointSensitivityIntegrand)(out, t) @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S f = sol.prob.f - if eltype(sol.u) <: StaticArrays.SArray - y = sol(t) - λ = adj_sol(t) - else + # if eltype(sol.u) <: StaticArrays.SArray + if ArrayInterfaceCore.ismutable(eltype(sol.u)) sol(y, t) adj_sol(λ, t) + else + y = sol(t) + λ = adj_sol(t) end isautojacvec = get_jacvec(sensealg) # y is aliased From bdde7083ca82fd52c925836681248c7a7b6bb7c3 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 6 Jul 2022 12:57:59 +0530 Subject: [PATCH 04/25] Update adjoint_common.jl --- src/adjoint_common.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index d622a5a6f..297e0c5a5 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -401,7 +401,6 @@ function (f::ReverseLossCallback)(integrator) copyto!(y, integrator.u[(end - idx + 1):end]) end - # if u isa StaticArrays.SArray if ArrayInterfaceCore.ismutable(u) # Warning: alias here! Be careful with λ gᵤ = isq ? λ : @view(λ[1:idx]) @@ -427,7 +426,6 @@ function (f::ReverseLossCallback)(integrator) F !== I && F !== (I, I) && ldiv!(F, Δλd) end - # if u isa StaticArrays.SArray if ArrayInterfaceCore.ismutable(u) u[diffvar_idxs] .+= Δλd else From 2c067abab0cad49aa9e8593f36976fc538a59263 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Mon, 11 Jul 2022 11:38:35 +0530 Subject: [PATCH 05/25] oop dispatch for dgdu function, some tests, some corrections --- Project.toml | 4 +- src/adjoint_common.jl | 4 +- src/concrete_solve.jl | 60 +++++++++++++++++++++++++++++- src/derivative_wrappers.jl | 12 +++--- test/adjoint.jl | 75 +++++++++++++++++++++++++++++++++++++- 5 files changed, 144 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index 48e386174..0721c1015 100644 --- a/Project.toml +++ b/Project.toml @@ -85,9 +85,11 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"] +test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "SimpleChains", "StaticArrays", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "SparseArrays"] diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 297e0c5a5..1258b0862 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -407,8 +407,8 @@ function (f::ReverseLossCallback)(integrator) g(gᵤ, y, p, t[cur_time[]], cur_time[]) else @assert sensealg isa QuadratureAdjoint - gᵤ = isq ? λ : @view(λ[1:idx]) - gᵤ = g(gᵤ, y, p, t[cur_time[]], cur_time[]) + outtype = DiffEqBase.parameterless_type(λ) + gᵤ = g(y, p, t[cur_time[]], cur_time[];outtype=outtype) end if issemiexplicitdae diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index d2d62fb55..f03c39664 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -328,6 +328,64 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, end end + function df(u, p, t, i;outtype=nothing) + if only_end + eltype(Δ) <: NoTangent && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 + # user did sol[end] on only_end + if typeof(_save_idxs) <: Number + x = vec(Δ[1]) + _out = adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ[1])) + else + _out = adapt(outtype, + vec(Δ[1])[_save_idxs]) + end + else + Δ isa NoTangent && return + if typeof(_save_idxs) <: Number + x = vec(Δ) + _out = adapt(outtype, @view(x[_save_idxs])) + elseif _save_idxs isa Colon + _out = adapt(outtype, vec(Δ)) + else + x = vec(Δ) + _out = adapt(outtype, @view(x[_save_idxs])) + end + end + else + !Base.isconcretetype(eltype(Δ)) && + (Δ[i] isa NoTangent || eltype(Δ) <: NoTangent) && return + if typeof(Δ) <: AbstractArray{<:AbstractArray} || typeof(Δ) <: DESolution + x = Δ[i] + if typeof(_save_idxs) <: Number + _out = @view(x[_save_idxs]) + elseif _save_idxs isa Colon + _out = vec(x) + else + _out = vec(@view(x[_save_idxs])) + end + else + if typeof(_save_idxs) <: Number + _out = adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[_save_idxs, i]) + elseif _save_idxs isa Colon + _out = vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) + else + _out = vec(adapt(outtype, + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:,i])) + end + end + end + return _out + end + if haskey(kwargs_adj, :callback_adj) cb2 = CallbackSet(cb, kwargs[:callback_adj]) else @@ -855,7 +913,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::ReverseDiffAdjo function reversediff_adjoint_forwardpass(_u0, _p) if (convert_tspan(sensealg) === nothing && - ((haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])))) || + ((haskey(kwargs, :callback) && has_a_callback(kwargs[:callback])))) || (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) _tspan = convert.(eltype(_p), prob.tspan) else diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 1dac7a791..721abcd8e 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -916,14 +916,14 @@ function accumulate_cost(dλ, y, p, t, S::TS, @unpack dg, dg_val, g, g_grad_config = S.diffcache if dg !== nothing if !(dg isa Tuple) - dg(dg_val, y, p, t) - dλ -= vec(dg_val) + dg_val = dg(y, p, t) + dλ -= dg_val else - dg[1](dg_val[1], y, p, t) - dλ -= vec(dg_val[1]) + dg[1](y, p, t) + dλ -= dg_val if dgrad !== nothing - dg[2](dg_val[2], y, p, t) - dgrad .-= vec(dg_val[2]) + dg[2](y, p, t) + dgrad -= dg_val end end else diff --git a/test/adjoint.jl b/test/adjoint.jl index 03e596ee5..0f648ea27 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -1,5 +1,5 @@ using SciMLSensitivity, OrdinaryDiffEq, RecursiveArrayTools, DiffEqBase, - ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote + ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote, SimpleChains, StaticArrays, Optimization, OptimizationFlux using Test function fb(du, u, p, t) @@ -849,3 +849,76 @@ using LinearAlgebra, SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, QuadGK end end end + +####Fully oop Adjoint + +u0 = @SArray Float32[2.0, 0.0] +datasize = 30 +tspan = (0.0f0, 1.5f0) +tsteps = range(tspan[1], tspan[2], length = datasize) + +function trueODE(u, p, t) + true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1] + ((u.^3)'true_A)' +end + +prob = ODEProblem(trueODE, u0, tspan) +data = Array(solve(prob, Tsit5(), saveat = tsteps)) + +sc = SimpleChain( + static(2), + Activation(x -> x.^3), + TurboDense{true}(tanh, static(50)), + TurboDense{true}(identity, static(2)) + ) + +p_nn = SimpleChains.init_params(sc) + +df(u,p,t) = sc(u,p) + +prob_nn = ODEProblem(df, u0, tspan, p_nn) +sol = solve(prob_nn, Tsit5();saveat=tsteps) +dg_disc(u, p, t, i;outtype=nothing) = data[:, i] .- u + +res = adjoint_sensitivities(sol,Tsit5();t=tsteps[end],dg_discrete=dg_disc, + sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) + +@test !iszero(res[1]) +@test !iszero(res[2]) + +G(u,p,t) = sum(abs2, ((data.-u)./2)) + +function dg(u,p,t) + @show u + return data[:, end] .- u +end + +res = adjoint_sensitivities(sol,Tsit5();dg_continuous=dg,g=G, + sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) + +@test !iszero(res[1]) +@test !iszero(res[2]) + +prob_nn = ODEProblem(f, u0, tspan) + +function predict_neuralode(p) + Array(solve(prob_nn, Tsit5();p=p,saveat=tsteps,sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP()))) +end + +function loss_neuralode(p) + pred = predict_neuralode(p) + loss = sum(abs2, data .- pred) + return loss, pred +end + +callback = function (p, l, pred; doplot = true) + display(l) + return false +end + +optf = Optimization.OptimizationFunction((x,p)->loss_neuralode(x), Optimization.AutoZygote()) +optprob = Optimization.OptimizationProblem(optf, p_nn) + +res = Optimization.solve(optprob, ADAM(0.05),callback=callback,maxiters=300) + +@test loss_neuralode(res.u) < 0.8 \ No newline at end of file From 971cea06e1b1a00d4d699036d5cbe49e69a944eb Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt <46929125+Abhishek-1Bhatt@users.noreply.github.com> Date: Mon, 11 Jul 2022 12:10:52 +0530 Subject: [PATCH 06/25] Update src/concrete_solve.jl --- src/concrete_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index ebdc987bb..a2054de75 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -913,7 +913,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, sensealg::ReverseDiffAdjo function reversediff_adjoint_forwardpass(_u0, _p) if (convert_tspan(sensealg) === nothing && - ((haskey(kwargs, :callback) && has_a_callback(kwargs[:callback])))) || + ((haskey(kwargs, :callback) && has_continuous_callback(kwargs[:callback])))) || (convert_tspan(sensealg) !== nothing && convert_tspan(sensealg)) _tspan = convert.(eltype(_p), prob.tspan) else From 10843add3c7cfb5261e90c1b8a5417aa29e1b07e Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Mon, 11 Jul 2022 19:38:47 +0530 Subject: [PATCH 07/25] revert df_iip plus correction in dgdu --- src/adjoint_common.jl | 4 ++-- src/concrete_solve.jl | 53 ++++++++++++++++--------------------------- 2 files changed, 21 insertions(+), 36 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index bbbf943e2..d78db2bec 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -410,11 +410,11 @@ function (f::ReverseLossCallback)(integrator) if ArrayInterfaceCore.ismutable(u) # Warning: alias here! Be careful with λ gᵤ = isq ? λ : @view(λ[1:idx]) - g(gᵤ, y, p, t[cur_time[]], cur_time[]) + dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) else @assert sensealg isa QuadratureAdjoint outtype = DiffEqBase.parameterless_type(λ) - gᵤ = g(y, p, t[cur_time[]], cur_time[];outtype=outtype) + gᵤ = dgdu(y, p, t[cur_time[]], cur_time[];outtype=outtype) end if issemiexplicitdae diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index ebdc987bb..c31fe2a53 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -246,7 +246,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, _save_idxs = save_idxs === nothing ? Colon() : save_idxs function adjoint_sensitivity_backpass(Δ) - function df(_out, u, p, t, i) + function df_iip(_out, u, p, t, i) outtype = typeof(_out) <: SubArray ? DiffEqBase.parameterless_type(_out.parent) : DiffEqBase.parameterless_type(_out) @@ -258,11 +258,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ[1]) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - if ArrayInterfaceCore.ismutable(u) - vec(_out) .= adapt(outtype, vec(Δ[1])) - else - _out = adapt(outtype, vec(Δ[1])) - end + vec(_out) .= adapt(outtype, vec(Δ[1])) else vec(@view(_out[_save_idxs])) .= adapt(outtype, vec(Δ[1])[_save_idxs]) @@ -273,11 +269,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, x = vec(Δ) _out[_save_idxs] .= adapt(outtype, @view(x[_save_idxs])) elseif _save_idxs isa Colon - if ArrayInterfaceCore.ismutable(u) - vec(_out) .= adapt(outtype, vec(Δ)) - else - _out = adapt(outtype, vec(Δ)) - end + vec(_out) .= adapt(outtype, vec(Δ)) else x = vec(Δ) vec(@view(_out[_save_idxs])) .= adapt(outtype, @view(x[_save_idxs])) @@ -291,11 +283,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, if typeof(_save_idxs) <: Number _out[_save_idxs] = @view(x[_save_idxs]) elseif _save_idxs isa Colon - if ArrayInterfaceCore.ismutable(u) - vec(_out) .= vec(x) - else - _out = vec(x) - end + vec(_out) .= vec(x) else vec(@view(_out[_save_idxs])) .= vec(@view(x[_save_idxs])) end @@ -305,15 +293,9 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, reshape(Δ, prod(size(Δ)[1:(end - 1)]), size(Δ)[end])[_save_idxs, i]) elseif _save_idxs isa Colon - if ArrayInterfaceCore.ismutable(u) - vec(_out) .= vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) - else - _out = vec(adapt(outtype, - reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) - end + vec(_out) .= vec(adapt(outtype, + reshape(Δ, prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) else vec(@view(_out[_save_idxs])) .= vec(adapt(outtype, reshape(Δ, @@ -323,12 +305,9 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, end end end - if !(ArrayInterfaceCore.ismutable(u0)) - return _out - end end - function df(u, p, t, i;outtype=nothing) + function df_oop(u, p, t, i;outtype=nothing) if only_end eltype(Δ) <: NoTangent && return if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 @@ -391,11 +370,17 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, else cb2 = cb end - - du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df, - sensealg = sensealg, - callback = cb2, - kwargs_adj...) + if ArrayInterfaceCore.ismutable(eltype(sol.u)) + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df_iip, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + else + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df_oop, + sensealg = sensealg, + callback = cb2, + kwargs_adj...) + end du0 = reshape(du0, size(u0)) dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing : From 6158ba91cb44327dc9d1b9327b22382e92114a54 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt <46929125+Abhishek-1Bhatt@users.noreply.github.com> Date: Mon, 11 Jul 2022 19:51:36 +0530 Subject: [PATCH 08/25] Update adjoint_common.jl --- src/adjoint_common.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index d78db2bec..5005c4067 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -410,7 +410,15 @@ function (f::ReverseLossCallback)(integrator) if ArrayInterfaceCore.ismutable(u) # Warning: alias here! Be careful with λ gᵤ = isq ? λ : @view(λ[1:idx]) - dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) + if dgdu !== nothing + dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) + # add discrete dgdp contribution + if dgdp !== nothing && !isq + gp = @view(λ[(idx + 1):end]) + dgdp(gp, y, p, t[cur_time[]], cur_time[]) + u[(idx + 1):length(λ)] .+= gp + end + end else @assert sensealg isa QuadratureAdjoint outtype = DiffEqBase.parameterless_type(λ) From a2d90fd74d82896c55205b4638299bb404a39f49 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt <46929125+Abhishek-1Bhatt@users.noreply.github.com> Date: Mon, 11 Jul 2022 19:55:09 +0530 Subject: [PATCH 09/25] Update adjoint_common.jl --- src/adjoint_common.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 5005c4067..b844d6b6b 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -412,12 +412,12 @@ function (f::ReverseLossCallback)(integrator) gᵤ = isq ? λ : @view(λ[1:idx]) if dgdu !== nothing dgdu(gᵤ, y, p, t[cur_time[]], cur_time[]) - # add discrete dgdp contribution - if dgdp !== nothing && !isq - gp = @view(λ[(idx + 1):end]) - dgdp(gp, y, p, t[cur_time[]], cur_time[]) - u[(idx + 1):length(λ)] .+= gp - end + # add discrete dgdp contribution + if dgdp !== nothing && !isq + gp = @view(λ[(idx + 1):end]) + dgdp(gp, y, p, t[cur_time[]], cur_time[]) + u[(idx + 1):length(λ)] .+= gp + end end else @assert sensealg isa QuadratureAdjoint From b18f9d3bc051e5d2a78b3f47b5dbff3adaf9ef87 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Thu, 14 Jul 2022 19:31:58 +0530 Subject: [PATCH 10/25] tests, returns --- src/derivative_wrappers.jl | 26 +++------ src/quadrature_adjoint.jl | 5 +- test/adjoint.jl | 75 +------------------------- test/adjoint_oop.jl | 106 +++++++++++++++++++++++++++++++++++++ 4 files changed, 118 insertions(+), 94 deletions(-) create mode 100644 test/adjoint_oop.jl diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index b63e743f9..910111460 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -613,7 +613,7 @@ function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, (dgrad[:] .= vec(tmp2)) end end - return dλ + return dy, dλ, dgrad end function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dy, @@ -912,25 +912,15 @@ end function accumulate_cost(dλ, y, p, t, S::TS, dgrad = nothing) where {TS <: SensitivityFunction} - @unpack dg, dg_val, g, g_grad_config = S.diffcache - if dg !== nothing - if !(dg isa Tuple) - dg_val = dg(y, p, t) - dλ -= dg_val - else - dg[1](y, p, t) - dλ -= dg_val - if dgrad !== nothing - dg[2](y, p, t) - dgrad -= dg_val - end + @unpack dgdu, dgdp = S.diffcache + + dλ -= dgdu(y, p, t) + if dgdp !== nothing + if dgrad !== nothing + dgrad -= dgdp(y, p, t) end - else - g.t = t - gradient!(dg_val, g, y, S.sensealg, g_grad_config) - dλ -= vec(dg_val) end - return dλ + return dλ, dgrad end function build_jac_config(alg, uf, u) diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index a93c5ad41..1f574dffd 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -38,10 +38,11 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t) λ, grad, y, dgrad, dy = split_states(u, t, S) - dλ = vecjacobian(y, λ, p, t, S) * (-one(eltype(λ))) + dy, dλ, dgrad = vecjacobian(y, λ, p, t, S;dgrad=dgrad, dy=dy) + dλ *= (-one(eltype(λ))) if !discrete - return accumulate_cost(dλ, y, p, t, S) + dλ, dgrad = accumulate_cost(dλ, y, p, t, S, dgrad) end return dλ end diff --git a/test/adjoint.jl b/test/adjoint.jl index b99b63252..fd9cc18d8 100644 --- a/test/adjoint.jl +++ b/test/adjoint.jl @@ -1,5 +1,5 @@ using SciMLSensitivity, OrdinaryDiffEq, RecursiveArrayTools, DiffEqBase, - ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote, SimpleChains, StaticArrays, Optimization, OptimizationFlux + ForwardDiff, Calculus, QuadGK, LinearAlgebra, Zygote using Test function fb(du, u, p, t) @@ -866,76 +866,3 @@ using LinearAlgebra, SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, QuadGK end end end - -####Fully oop Adjoint - -u0 = @SArray Float32[2.0, 0.0] -datasize = 30 -tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) - -function trueODE(u, p, t) - true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1] - ((u.^3)'true_A)' -end - -prob = ODEProblem(trueODE, u0, tspan) -data = Array(solve(prob, Tsit5(), saveat = tsteps)) - -sc = SimpleChain( - static(2), - Activation(x -> x.^3), - TurboDense{true}(tanh, static(50)), - TurboDense{true}(identity, static(2)) - ) - -p_nn = SimpleChains.init_params(sc) - -df(u,p,t) = sc(u,p) - -prob_nn = ODEProblem(df, u0, tspan, p_nn) -sol = solve(prob_nn, Tsit5();saveat=tsteps) -dg_disc(u, p, t, i;outtype=nothing) = data[:, i] .- u - -res = adjoint_sensitivities(sol,Tsit5();t=tsteps[end],dg_discrete=dg_disc, - sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) - -@test !iszero(res[1]) -@test !iszero(res[2]) - -G(u,p,t) = sum(abs2, ((data.-u)./2)) - -function dg(u,p,t) - @show u - return data[:, end] .- u -end - -res = adjoint_sensitivities(sol,Tsit5();dg_continuous=dg,g=G, - sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) - -@test !iszero(res[1]) -@test !iszero(res[2]) - -prob_nn = ODEProblem(f, u0, tspan) - -function predict_neuralode(p) - Array(solve(prob_nn, Tsit5();p=p,saveat=tsteps,sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP()))) -end - -function loss_neuralode(p) - pred = predict_neuralode(p) - loss = sum(abs2, data .- pred) - return loss, pred -end - -callback = function (p, l, pred; doplot = true) - display(l) - return false -end - -optf = Optimization.OptimizationFunction((x,p)->loss_neuralode(x), Optimization.AutoZygote()) -optprob = Optimization.OptimizationProblem(optf, p_nn) - -res = Optimization.solve(optprob, ADAM(0.05),callback=callback,maxiters=300) - -@test loss_neuralode(res.u) < 0.8 \ No newline at end of file diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl new file mode 100644 index 000000000..f570417ff --- /dev/null +++ b/test/adjoint_oop.jl @@ -0,0 +1,106 @@ +using SciMLSensitivity,OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, ForwardDiff, Zygote +using Test + +u0 = @SArray Float32[2.0, 0.0] +datasize = 30 +tspan = (0.0f0, 1.5f0) +tsteps = range(tspan[1], tspan[2], length = datasize) + +function trueODE(u, p, t) + true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1] + ((u.^3)'true_A)' +end + +prob = ODEProblem(trueODE, u0, tspan) +sol_n = solve(prob, Tsit5(), saveat = tsteps) +data = Array(solve(prob, Tsit5(), saveat = tsteps)) + +sc = SimpleChain( + static(2), + Activation(x -> x.^3), + TurboDense{true}(tanh, static(50)), + TurboDense{true}(identity, static(2)) + ) + +p_nn = SimpleChains.init_params(sc) + +df(u,p,t) = sc(u,p) + +prob_nn = ODEProblem(df, u0, tspan, p_nn) +sol = solve(prob_nn, Tsit5();saveat=tsteps) +dg_disc(u, p, t, i;outtype=nothing) = data[:, i] .- u + +du0, dp = adjoint_sensitivities(sol,Tsit5();t=tsteps,dgdu_discrete=dg_disc, + sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) + +@test !iszero(du0) +@test !iszero(dp) + +## numerical + +function G_p(p) + tmp_prob = remake(prob_nn,u0=prob_nn.u0,p=p) + A = Array(solve(tmp_prob,Tsit5(),saveat=tsteps, + sensealg=SensitivityADPassThrough())) + + return sum(((data .- A).^2)./2) +end +function G_u(u0) + tmp_prob = remake(prob_nn,u0=u0,p=p_nn) + A = Array(solve(tmp_prob,Tsit5(),saveat=tsteps, + sensealg=SensitivityADPassThrough())) + return sum(((data .- A).^2)./2) +end +G_p(p_nn) +G_u(u0) +n_du0 = ForwardDiff.gradient(G_p,p_nn) +n_dp = ForwardDiff.gradient(G_u,u0) + +@test_broken n_du0 ≈ du0 +@test_broken n_dp ≈ dp + +## Continuous case + +G(u,p,t) = sum(((data .- u).^2)./2) + +function dg(u,p,t) + return data[:, end] .- u +end + +du0, dp = adjoint_sensitivities(sol,Tsit5();dgdu_continuous=dg,g=G, + sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) + +@test !iszero(du0) +@test !iszero(dp) +##numerical + +function G_p(p) + tmp_prob = remake(prob_nn,p=p) + sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5) + res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above) + res +end + +function G_u(u0) + tmp_prob = remake(prob_nn,u0=u0) + sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5) + res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above) + res +end + +n_du0 = ForwardDiff.gradient(G_u,u0) +n_dp = ForwardDiff.gradient(G_p,p_nn) + +@test_broken n_du0 ≈ du0 +@test_broken n_dp ≈ dp + +#concrete_solve + +du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, + abstol = 1e-5, reltol = 1e-5, + saveat = tsteps, + sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))), + u0, p_nn) + +@test !iszero(du0) +@test !iszero(dp) \ No newline at end of file From 174812a4975c382f25637187155c2764513ae6f9 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Thu, 14 Jul 2022 19:47:16 +0530 Subject: [PATCH 11/25] Updated adjoint_oop.jl --- test/adjoint_oop.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index f570417ff..6e6855f48 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -76,15 +76,15 @@ du0, dp = adjoint_sensitivities(sol,Tsit5();dgdu_continuous=dg,g=G, function G_p(p) tmp_prob = remake(prob_nn,p=p) - sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5) - res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above) + sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12) + res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above) res end function G_u(u0) tmp_prob = remake(prob_nn,u0=u0) - sol = solve(tmp_prob,Tsit5(),abstol=1e-5,reltol=1e-5) - res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-5,rtol=1e-5) # sol_n(t):numerical solution/data(above) + sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12) + res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above) res end @@ -97,7 +97,7 @@ n_dp = ForwardDiff.gradient(G_p,p_nn) #concrete_solve du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, - abstol = 1e-5, reltol = 1e-5, + abstol = 1e-12, reltol = 1e-12, saveat = tsteps, sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))), u0, p_nn) From 04ade0d2f2cfc6ae4a6c60685cd0828f7f4042e7 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Thu, 14 Jul 2022 19:53:53 +0530 Subject: [PATCH 12/25] formatter --- src/adjoint_common.jl | 4 +- src/concrete_solve.jl | 16 +++---- src/derivative_wrappers.jl | 6 +-- src/quadrature_adjoint.jl | 8 ++-- test/adjoint_oop.jl | 85 +++++++++++++++++++------------------- 5 files changed, 61 insertions(+), 58 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 07e313385..e74f2debc 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -422,7 +422,7 @@ function (f::ReverseLossCallback)(integrator) else @assert sensealg isa QuadratureAdjoint outtype = DiffEqBase.parameterless_type(λ) - gᵤ = dgdu(y, p, t[cur_time[]], cur_time[];outtype=outtype) + gᵤ = dgdu(y, p, t[cur_time[]], cur_time[]; outtype = outtype) end if issemiexplicitdae @@ -439,7 +439,7 @@ function (f::ReverseLossCallback)(integrator) if F !== nothing F !== I && F !== (I, I) && ldiv!(F, Δλd) end - + if ArrayInterfaceCore.ismutable(u) u[diffvar_idxs] .+= Δλd else diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index 9dffe8d68..298501807 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -307,7 +307,7 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, end end - function df_oop(u, p, t, i;outtype=nothing) + function df_oop(u, p, t, i; outtype = nothing) if only_end eltype(Δ) <: NoTangent && return if typeof(Δ) <: AbstractArray{<:AbstractArray} && length(Δ) == 1 && i == 1 @@ -353,12 +353,12 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, elseif _save_idxs isa Colon _out = vec(adapt(outtype, reshape(Δ, prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:, i])) + size(Δ)[end])[:, i])) else _out = vec(adapt(outtype, - reshape(Δ, - prod(size(Δ)[1:(end - 1)]), - size(Δ)[end])[:,i])) + reshape(Δ, + prod(size(Δ)[1:(end - 1)]), + size(Δ)[end])[:, i])) end end end @@ -371,12 +371,14 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg, cb2 = cb end if ArrayInterfaceCore.ismutable(eltype(sol.u)) - du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df_iip, + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, + dgdu_discrete = df_iip, sensealg = sensealg, callback = cb2, kwargs_adj...) else - du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, dgdu_discrete = df_oop, + du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, + dgdu_discrete = df_oop, sensealg = sensealg, callback = cb2, kwargs_adj...) diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 910111460..08546bc1e 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -580,15 +580,15 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, end function _vecjacobian(y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, - W) where {TS <: SensitivityFunction} + W) where {TS <: SensitivityFunction} @unpack sensealg, f = S prob = getprob(S) isautojacvec = get_jacvec(sensealg) - + if W === nothing _dy, back = Zygote.pullback(y, p) do u, p - vec(f(u, p, t)) + vec(f(u, p, t)) end else _dy, back = Zygote.pullback(y, p) do u, p diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 1f574dffd..54069f536 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -38,11 +38,11 @@ function (S::ODEQuadratureAdjointSensitivityFunction)(u, p, t) λ, grad, y, dgrad, dy = split_states(u, t, S) - dy, dλ, dgrad = vecjacobian(y, λ, p, t, S;dgrad=dgrad, dy=dy) - dλ *= (-one(eltype(λ))) - + dy, dλ, dgrad = vecjacobian(y, λ, p, t, S; dgrad = dgrad, dy = dy) + dλ *= (-one(eltype(λ))) + if !discrete - dλ, dgrad = accumulate_cost(dλ, y, p, t, S, dgrad) + dλ, dgrad = accumulate_cost(dλ, y, p, t, S, dgrad) end return dλ end diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 6e6855f48..c6e5ddcf3 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -1,4 +1,5 @@ -using SciMLSensitivity,OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, ForwardDiff, Zygote +using SciMLSensitivity, OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, ForwardDiff, + Zygote using Test u0 = @SArray Float32[2.0, 0.0] @@ -8,30 +9,28 @@ tsteps = range(tspan[1], tspan[2], length = datasize) function trueODE(u, p, t) true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1] - ((u.^3)'true_A)' + ((u .^ 3)'true_A)' end prob = ODEProblem(trueODE, u0, tspan) sol_n = solve(prob, Tsit5(), saveat = tsteps) data = Array(solve(prob, Tsit5(), saveat = tsteps)) -sc = SimpleChain( - static(2), - Activation(x -> x.^3), - TurboDense{true}(tanh, static(50)), - TurboDense{true}(identity, static(2)) - ) +sc = SimpleChain(static(2), + Activation(x -> x .^ 3), + TurboDense{true}(tanh, static(50)), + TurboDense{true}(identity, static(2))) p_nn = SimpleChains.init_params(sc) -df(u,p,t) = sc(u,p) +df(u, p, t) = sc(u, p) prob_nn = ODEProblem(df, u0, tspan, p_nn) -sol = solve(prob_nn, Tsit5();saveat=tsteps) -dg_disc(u, p, t, i;outtype=nothing) = data[:, i] .- u +sol = solve(prob_nn, Tsit5(); saveat = tsteps) +dg_disc(u, p, t, i; outtype = nothing) = data[:, i] .- u -du0, dp = adjoint_sensitivities(sol,Tsit5();t=tsteps,dgdu_discrete=dg_disc, - sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) +du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) @@ -39,57 +38,59 @@ du0, dp = adjoint_sensitivities(sol,Tsit5();t=tsteps,dgdu_discrete=dg_disc, ## numerical function G_p(p) - tmp_prob = remake(prob_nn,u0=prob_nn.u0,p=p) - A = Array(solve(tmp_prob,Tsit5(),saveat=tsteps, - sensealg=SensitivityADPassThrough())) + tmp_prob = remake(prob_nn, u0 = prob_nn.u0, p = p) + A = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = SensitivityADPassThrough())) - return sum(((data .- A).^2)./2) + return sum(((data .- A) .^ 2) ./ 2) end function G_u(u0) - tmp_prob = remake(prob_nn,u0=u0,p=p_nn) - A = Array(solve(tmp_prob,Tsit5(),saveat=tsteps, - sensealg=SensitivityADPassThrough())) - return sum(((data .- A).^2)./2) + tmp_prob = remake(prob_nn, u0 = u0, p = p_nn) + A = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = SensitivityADPassThrough())) + return sum(((data .- A) .^ 2) ./ 2) end G_p(p_nn) G_u(u0) -n_du0 = ForwardDiff.gradient(G_p,p_nn) -n_dp = ForwardDiff.gradient(G_u,u0) +n_du0 = ForwardDiff.gradient(G_p, p_nn) +n_dp = ForwardDiff.gradient(G_u, u0) @test_broken n_du0 ≈ du0 @test_broken n_dp ≈ dp ## Continuous case -G(u,p,t) = sum(((data .- u).^2)./2) +G(u, p, t) = sum(((data .- u) .^ 2) ./ 2) -function dg(u,p,t) +function dg(u, p, t) return data[:, end] .- u end -du0, dp = adjoint_sensitivities(sol,Tsit5();dgdu_continuous=dg,g=G, - sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())) +du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) ##numerical function G_p(p) - tmp_prob = remake(prob_nn,p=p) - sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12) - res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above) - res + tmp_prob = remake(prob_nn, p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum(sol_n(t) .- sol(t)) .^ 2) ./ 2, 0.0, 1.0, atol = 1e-12, + rtol = 1e-12) # sol_n(t):numerical solution/data(above) + res end function G_u(u0) - tmp_prob = remake(prob_nn,u0=u0) - sol = solve(tmp_prob,Tsit5(),abstol=1e-12,reltol=1e-12) - res,err = quadgk((t)-> (sum(sol_n(t) .- sol(t)).^2)./2,0.0,1.0,atol=1e-12,rtol=1e-12) # sol_n(t):numerical solution/data(above) - res + tmp_prob = remake(prob_nn, u0 = u0) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum(sol_n(t) .- sol(t)) .^ 2) ./ 2, 0.0, 1.0, atol = 1e-12, + rtol = 1e-12) # sol_n(t):numerical solution/data(above) + res end -n_du0 = ForwardDiff.gradient(G_u,u0) -n_dp = ForwardDiff.gradient(G_p,p_nn) +n_du0 = ForwardDiff.gradient(G_u, u0) +n_dp = ForwardDiff.gradient(G_p, p_nn) @test_broken n_du0 ≈ du0 @test_broken n_dp ≈ dp @@ -97,10 +98,10 @@ n_dp = ForwardDiff.gradient(G_p,p_nn) #concrete_solve du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, - abstol = 1e-12, reltol = 1e-12, - saveat = tsteps, - sensealg = QuadratureAdjoint(autojacvec=ZygoteVJP()))), - u0, p_nn) + abstol = 1e-12, reltol = 1e-12, + saveat = tsteps, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), + u0, p_nn) @test !iszero(du0) -@test !iszero(dp) \ No newline at end of file +@test !iszero(dp) From 9140817984d90592306a175a3cf72a569db75fbd Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Tue, 19 Jul 2022 18:07:29 +0530 Subject: [PATCH 13/25] fix oop adjoint tests --- src/adjoint_common.jl | 8 ++++--- test/adjoint_oop.jl | 54 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index e74f2debc..895c2d6e1 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -371,7 +371,7 @@ inplace_sensitivity(S::SensitivityFunction) = isinplace(getprob(S)) struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg1Type, dg2Type, - cacheType} + cacheType, solType} isq::Bool λ::λType t::timeType @@ -383,6 +383,7 @@ struct ReverseLossCallback{λType, timeType, yType, RefType, FMType, AlgType, dg dgdu::dg1Type dgdp::dg2Type diffcache::cacheType + sol::solType end function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) @@ -394,11 +395,11 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) idx = length(prob.u0) return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, - sensealg, dgdu, dgdp, sensefun.diffcache) + sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) end function (f::ReverseLossCallback)(integrator) - @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp = f + @unpack isq, λ, t, y, cur_time, idx, F, sensealg, dgdu, dgdp, sol = f @unpack diffvar_idxs, algevar_idxs, issemiexplicitdae, J, uf, f_cache, jac_config = f.diffcache p, u = integrator.p, integrator.u @@ -422,6 +423,7 @@ function (f::ReverseLossCallback)(integrator) else @assert sensealg isa QuadratureAdjoint outtype = DiffEqBase.parameterless_type(λ) + y = sol(t[cur_time[]]) gᵤ = dgdu(y, p, t[cur_time[]], cur_time[]; outtype = outtype) end diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index c6e5ddcf3..498a68afe 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -2,6 +2,9 @@ using SciMLSensitivity, OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, Forw Zygote using Test + + +##### u0 = @SArray Float32[2.0, 0.0] datasize = 30 tspan = (0.0f0, 1.5f0) @@ -27,7 +30,8 @@ df(u, p, t) = sc(u, p) prob_nn = ODEProblem(df, u0, tspan, p_nn) sol = solve(prob_nn, Tsit5(); saveat = tsteps) -dg_disc(u, p, t, i; outtype = nothing) = data[:, i] .- u + +dg_disc(u, p, t, i; outtype = nothing) = u .- data[:, i] du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) @@ -52,18 +56,18 @@ function G_u(u0) end G_p(p_nn) G_u(u0) -n_du0 = ForwardDiff.gradient(G_p, p_nn) -n_dp = ForwardDiff.gradient(G_u, u0) +n_dp = ForwardDiff.gradient(G_p, p_nn) +n_du0 = ForwardDiff.gradient(G_u, u0) -@test_broken n_du0 ≈ du0 -@test_broken n_dp ≈ dp +@test n_du0 ≈ du0 rtol = 1e-3 +@test n_dp ≈ dp' rtol = 1e-3 ## Continuous case G(u, p, t) = sum(((data .- u) .^ 2) ./ 2) function dg(u, p, t) - return data[:, end] .- u + return u .- Array(sol_n(t)) end du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G, @@ -76,7 +80,7 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G, function G_p(p) tmp_prob = remake(prob_nn, p = p) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(sol_n(t) .- sol(t)) .^ 2) ./ 2, 0.0, 1.0, atol = 1e-12, + res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12, rtol = 1e-12) # sol_n(t):numerical solution/data(above) res end @@ -84,7 +88,7 @@ end function G_u(u0) tmp_prob = remake(prob_nn, u0 = u0) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(sol_n(t) .- sol(t)) .^ 2) ./ 2, 0.0, 1.0, atol = 1e-12, + res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12, rtol = 1e-12) # sol_n(t):numerical solution/data(above) res end @@ -92,8 +96,8 @@ end n_du0 = ForwardDiff.gradient(G_u, u0) n_dp = ForwardDiff.gradient(G_p, p_nn) -@test_broken n_du0 ≈ du0 -@test_broken n_dp ≈ dp +@test n_du0 ≈ du0 rtol=1e-3 +@test n_dp ≈ dp' rtol=1e-3 #concrete_solve @@ -105,3 +109,33 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, @test !iszero(du0) @test !iszero(dp) + +#####Delete################################################################ +using Flux + +u0 = Float32[2.0; 0.0] +datasize = 30 +tspan = (0.0f0, 1.5f0) +tsteps = range(tspan[1], tspan[2], length = datasize) + +function trueODEfunc(u, p, t) + true_A = [-0.1f0 2.0f0; -2.0f0 -0.1f0] + return ((u.^3)'true_A)' +end + +prob_trueode = ODEProblem(trueODEfunc, u0, tspan) +ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) + +dudt2 = Chain((x) -> x.^3, + Dense(2, 50, tanh), + Dense(50, 2)) +p, re = Flux.destructure(dudt2) +f(u, p, t) = re(p)(u) + +prob_nn = ODEProblem(f, u0, tspan) + +du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, + abstol = 1e-12, reltol = 1e-12, + saveat = tsteps, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), + u0, p) \ No newline at end of file From a3747b471ab6783a95447a9deb221473a450b6e7 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 20 Jul 2022 12:21:27 +0530 Subject: [PATCH 14/25] OOP Adjoint on Numerical solve --- src/quadrature_adjoint.jl | 7 +- test/adjoint_oop.jl | 133 +++++++++++++++++++++++++++++--------- 2 files changed, 107 insertions(+), 33 deletions(-) diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 54069f536..f73315317 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -332,8 +332,13 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi end for i in (length(t) - 1):-1:1 - res .+= quadgk(integrand, t[i], t[i + 1], + if ArrayInterfaceCore.ismutable(res) + res .+= quadgk(integrand, t[i], t[i + 1], + atol = abstol, rtol = reltol)[1] + else + res += quadgk(integrand, t[i], t[i + 1], atol = abstol, rtol = reltol)[1] + end if t[i] == t[i + 1] for cb in callback.discrete_callbacks if t[i] ∈ cb.affect!.event_times diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 498a68afe..12da13a96 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -2,9 +2,108 @@ using SciMLSensitivity, OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, Forw Zygote using Test +##Adjoints of numerical solve + +u0 = @SVector [1.0f0, 1.0f0] +p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0] +tspan = [0.0f0, 5.0f0] +datasize = 20 +tsteps = range(tspan[1], tspan[2], length = datasize) + +function f(u, p, t) + p*u +end + +prob = ODEProblem(f, u0, tspan, p) +sol = solve(prob, Tsit5(), saveat=tsteps, abstol = 1e-12, reltol = 1e-12) + +## Discrete Case +dg_disc(u, p, t, i; outtype = nothing) = u .- 1 + +du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) + +## with ForwardDiff +function G_p(p) + tmp_prob = remake(prob, p = p) + u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)) + + return sum(((1 .- u) .^ 2) ./ 2) +end + +function G_u(u0) + tmp_prob = remake(prob, u0 = u0) + u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)) + return sum(((1 .- u) .^ 2) ./ 2) +end + +G_p(p) +G_u(u0) +n_dp = ForwardDiff.gradient(G_p, p) +n_du0 = ForwardDiff.gradient(G_u, u0) + +@test n_du0 ≈ du0 rtol = 1e-3 +@test_broken n_dp ≈ dp' rtol = 1e-3 +@test sum(n_dp - dp') < 8.0 + +## Continuous Case + +g(u, p, t) = sum((u.^2)./2) + +function dg(u, p, t) + u +end + +du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) + +@test !iszero(du0) +@test !iszero(dp) + +##numerical + +function G_p(p) + tmp_prob = remake(prob, p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12, + rtol = 1e-12) + res +end + +function G_u(u0) + tmp_prob = remake(prob, u0 = u0) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) + res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12, + rtol = 1e-12) + res +end + +n_du0 = ForwardDiff.gradient(G_u, u0) +n_dp = ForwardDiff.gradient(G_p, p) + +@test_broken n_du0 ≈ du0 rtol=1e-3 +@test_broken n_dp ≈ dp' rtol=1e-3 + +@test sum(n_du0 - du0) < 1.0 +@test sum(n_dp - dp) < 5.0 + +## concrete solve + +du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, + abstol = 1e-6, reltol = 1e-6, + saveat = tsteps, + sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), + u0, p) + +@test !iszero(du0) +@test !iszero(dp) -##### + + +##Neural ODE adjoint with SimpleChains u0 = @SArray Float32[2.0, 0.0] datasize = 30 tspan = (0.0f0, 1.5f0) @@ -108,34 +207,4 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, u0, p_nn) @test !iszero(du0) -@test !iszero(dp) - -#####Delete################################################################ -using Flux - -u0 = Float32[2.0; 0.0] -datasize = 30 -tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) - -function trueODEfunc(u, p, t) - true_A = [-0.1f0 2.0f0; -2.0f0 -0.1f0] - return ((u.^3)'true_A)' -end - -prob_trueode = ODEProblem(trueODEfunc, u0, tspan) -ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps)) - -dudt2 = Chain((x) -> x.^3, - Dense(2, 50, tanh), - Dense(50, 2)) -p, re = Flux.destructure(dudt2) -f(u, p, t) = re(p)(u) - -prob_nn = ODEProblem(f, u0, tspan) - -du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, - abstol = 1e-12, reltol = 1e-12, - saveat = tsteps, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), - u0, p) \ No newline at end of file +@test !iszero(dp) \ No newline at end of file From 15fadafe5370b0f6477163a6b44031a42ccc6563 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 20 Jul 2022 13:33:50 +0530 Subject: [PATCH 15/25] bug fix, formatter --- src/quadrature_adjoint.jl | 12 +++++++++--- test/adjoint_oop.jl | 41 +++++++++++++++++++-------------------- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index f73315317..be099f7d8 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -106,7 +106,13 @@ end (dgdu_continuous === nothing && dgdp_continuous === nothing || g !== nothing)) - λ = zero(u0) + if ArrayInterfaceCore.ismutable(u0) + len = length(u0) + λ = similar(u0, len) + λ .= false + else + λ = zero(u0) + end sense = ODEQuadratureAdjointSensitivityFunction(g, sensealg, discrete, sol, dgdu_continuous, dgdp_continuous) @@ -334,10 +340,10 @@ function _adjoint_sensitivities(sol, sensealg::QuadratureAdjoint, alg; t = nothi for i in (length(t) - 1):-1:1 if ArrayInterfaceCore.ismutable(res) res .+= quadgk(integrand, t[i], t[i + 1], - atol = abstol, rtol = reltol)[1] + atol = abstol, rtol = reltol)[1] else res += quadgk(integrand, t[i], t[i + 1], - atol = abstol, rtol = reltol)[1] + atol = abstol, rtol = reltol)[1] end if t[i] == t[i + 1] for cb in callback.discrete_callbacks diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 12da13a96..2165b8eaa 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -11,11 +11,11 @@ datasize = 20 tsteps = range(tspan[1], tspan[2], length = datasize) function f(u, p, t) - p*u + p * u end prob = ODEProblem(f, u0, tspan, p) -sol = solve(prob, Tsit5(), saveat=tsteps, abstol = 1e-12, reltol = 1e-12) +sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-12, reltol = 1e-12) ## Discrete Case dg_disc(u, p, t, i; outtype = nothing) = u .- 1 @@ -27,7 +27,7 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_dis function G_p(p) tmp_prob = remake(prob, p = p) u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)) + sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12)) return sum(((1 .- u) .^ 2) ./ 2) end @@ -35,7 +35,7 @@ end function G_u(u0) tmp_prob = remake(prob, u0 = u0) u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough(), abstol=1e-12, reltol=1e-12)) + sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12)) return sum(((1 .- u) .^ 2) ./ 2) end @@ -44,13 +44,13 @@ G_u(u0) n_dp = ForwardDiff.gradient(G_p, p) n_du0 = ForwardDiff.gradient(G_u, u0) -@test n_du0 ≈ du0 rtol = 1e-3 -@test_broken n_dp ≈ dp' rtol = 1e-3 +@test n_du0≈du0 rtol=1e-3 +@test_broken n_dp≈dp' rtol=1e-3 @test sum(n_dp - dp') < 8.0 ## Continuous Case -g(u, p, t) = sum((u.^2)./2) +g(u, p, t) = sum((u .^ 2) ./ 2) function dg(u, p, t) u @@ -67,7 +67,7 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, function G_p(p) tmp_prob = remake(prob, p = p) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12, + res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12, rtol = 1e-12) res end @@ -75,7 +75,7 @@ end function G_u(u0) tmp_prob = remake(prob, u0 = u0) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum((sol(t).^2)./2)), 0.0, 5.0, atol = 1e-12, + res, err = quadgk((t) -> (sum((sol(t) .^ 2) ./ 2)), 0.0, 5.0, atol = 1e-12, rtol = 1e-12) res end @@ -83,8 +83,8 @@ end n_du0 = ForwardDiff.gradient(G_u, u0) n_dp = ForwardDiff.gradient(G_p, p) -@test_broken n_du0 ≈ du0 rtol=1e-3 -@test_broken n_dp ≈ dp' rtol=1e-3 +@test_broken n_du0≈du0 rtol=1e-3 +@test_broken n_dp≈dp' rtol=1e-3 @test sum(n_du0 - du0) < 1.0 @test sum(n_dp - dp) < 5.0 @@ -100,9 +100,6 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, @test !iszero(du0) @test !iszero(dp) - - - ##Neural ODE adjoint with SimpleChains u0 = @SArray Float32[2.0, 0.0] datasize = 30 @@ -158,8 +155,8 @@ G_u(u0) n_dp = ForwardDiff.gradient(G_p, p_nn) n_du0 = ForwardDiff.gradient(G_u, u0) -@test n_du0 ≈ du0 rtol = 1e-3 -@test n_dp ≈ dp' rtol = 1e-3 +@test n_du0≈du0 rtol=1e-3 +@test n_dp≈dp' rtol=1e-3 ## Continuous case @@ -179,7 +176,8 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G, function G_p(p) tmp_prob = remake(prob_nn, p = p) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12, + res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5, + atol = 1e-12, rtol = 1e-12) # sol_n(t):numerical solution/data(above) res end @@ -187,7 +185,8 @@ end function G_u(u0) tmp_prob = remake(prob_nn, u0 = u0) sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)).^2)./2)), 0.0, 1.5, atol = 1e-12, + res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5, + atol = 1e-12, rtol = 1e-12) # sol_n(t):numerical solution/data(above) res end @@ -195,8 +194,8 @@ end n_du0 = ForwardDiff.gradient(G_u, u0) n_dp = ForwardDiff.gradient(G_p, p_nn) -@test n_du0 ≈ du0 rtol=1e-3 -@test n_dp ≈ dp' rtol=1e-3 +@test n_du0≈du0 rtol=1e-3 +@test n_dp≈dp' rtol=1e-3 #concrete_solve @@ -207,4 +206,4 @@ du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, u0, p_nn) @test !iszero(du0) -@test !iszero(dp) \ No newline at end of file +@test !iszero(dp) From 40b7f2a8374baabe569384fa54ed2082f5dd7464 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 20 Jul 2022 14:45:47 +0530 Subject: [PATCH 16/25] fix --- src/adjoint_common.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 895c2d6e1..de0b92a3b 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -393,9 +393,12 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) @unpack factorized_mass_matrix = sensefun.diffcache prob = getprob(sensefun) idx = length(prob.u0) - - return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, - sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) + if ArrayInterfaceCore.ismutable(y) + return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, + sensealg, dgdu, dgdp, sensefun.diffcache, nothing) + else + return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, + sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) end function (f::ReverseLossCallback)(integrator) From 651759104eb291801999c2e44b278c7204f68930 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 20 Jul 2022 14:52:11 +0530 Subject: [PATCH 17/25] fix error --- src/adjoint_common.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index de0b92a3b..582a7ec1b 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -395,10 +395,11 @@ function ReverseLossCallback(sensefun, λ, t, dgdu, dgdp, cur_time) idx = length(prob.u0) if ArrayInterfaceCore.ismutable(y) return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, - sensealg, dgdu, dgdp, sensefun.diffcache, nothing) + sensealg, dgdu, dgdp, sensefun.diffcache, nothing) else return ReverseLossCallback(isq, λ, t, y, cur_time, idx, factorized_mass_matrix, - sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) + sensealg, dgdu, dgdp, sensefun.diffcache, sensefun.sol) + end end function (f::ReverseLossCallback)(integrator) From e6cf65baa0ca954629a1df965b4071abf7a56a9b Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Wed, 20 Jul 2022 15:05:11 +0530 Subject: [PATCH 18/25] Added adjoint_oop.jl to runtests.jl --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 71ef461d9..0ce1a99e4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,7 @@ end @time @safetestset "Adjoint Sensitivity" begin include("adjoint.jl") end @time @safetestset "Continuous adjoint params" begin include("adjoint_param.jl") end @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end + @time @safetestset "Fully Out of Place adjoint sensitivity with StaticArrays and SimpleChains" begin include("adjoint_oop.jl") end end if GROUP == "All" || GROUP == "Core4" From d354ba79c801babe34fad29d555eec0da02e74e7 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Fri, 12 Aug 2022 20:11:24 +0530 Subject: [PATCH 19/25] StaticArrays AD, tests --- Project.toml | 3 +- src/SciMLSensitivity.jl | 4 +- src/quadrature_adjoint.jl | 3 +- src/staticarrays.jl | 23 ++++ test/adjoint_oop.jl | 274 ++++++++++++++++++-------------------- test/runtests.jl | 2 +- 6 files changed, 159 insertions(+), 150 deletions(-) create mode 100644 src/staticarrays.jl diff --git a/Project.toml b/Project.toml index 1722f5160..49efce3e1 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -88,9 +89,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 60f61416f..9c7a489ad 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -14,6 +14,7 @@ import ZygoteRules, Zygote, ReverseDiff import ArrayInterfaceCore, ArrayInterfaceTracker import Enzyme import GPUArraysCore +using StaticArrays import PreallocationTools: dualcache, get_tmp, DiffCache @@ -24,7 +25,7 @@ using EllipsisNotation using Markdown using Reexport -import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented +import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, project_type, _eltype_projectto, rrule abstract type SensitivityFunction end abstract type TransformedFunction end @@ -45,6 +46,7 @@ include("concrete_solve.jl") include("second_order.jl") include("steadystate_adjoint.jl") include("sde_tools.jl") +include("staticarrays.jl") # AD Extensions include("reversediff.jl") diff --git a/src/quadrature_adjoint.jl b/src/quadrature_adjoint.jl index 593a980ad..4391b8943 100644 --- a/src/quadrature_adjoint.jl +++ b/src/quadrature_adjoint.jl @@ -228,8 +228,7 @@ end function (S::AdjointSensitivityIntegrand)(out, t) @unpack y, λ, pJ, pf, p, f_cache, dgdp_cache, paramjac_config, sensealg, sol, adj_sol = S f = sol.prob.f - # if eltype(sol.u) <: StaticArrays.SArray - if ArrayInterfaceCore.ismutable(eltype(sol.u)) + if ArrayInterfaceCore.ismutable(y) sol(y, t) adj_sol(λ, t) else diff --git a/src/staticarrays.jl b/src/staticarrays.jl new file mode 100644 index 000000000..1cb924882 --- /dev/null +++ b/src/staticarrays.jl @@ -0,0 +1,23 @@ +### Projecting a tuple to SMatrix leads to ChainRulesCore._projection_mismatch by default, so overloaded here +function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArrays.SArray) + dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray + dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) + return project_type(project)(dz...) +end + +### Project SArray to SArray +function ProjectTo(x::StaticArrays.SArray{S,T}) where {S, T} + return ProjectTo{StaticArrays.SArray}(; element=_eltype_projectto(T), axes=S) +end + +function (project::ProjectTo{StaticArrays.SArray})(dx::AbstractArray{S,M}) where {S,M} + return StaticArrays.SArray{project.axes}(dx) +end + +### Adjoint for SArray constructor + +function rrule(::Type{T}, x::Tuple) where {T<:StaticArrays.SArray} + project_x = ProjectTo(x) + Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) + return T(x), Array_pullback +end \ No newline at end of file diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 2165b8eaa..630c6f8ac 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -1,52 +1,134 @@ -using SciMLSensitivity, OrdinaryDiffEq, SimpleChains, StaticArrays, QuadGK, ForwardDiff, +using SciMLSensitivity, OrdinaryDiffEq, StaticArrays, QuadGK, ForwardDiff, Zygote using Test -##Adjoints of numerical solve +##StaticArrays rrule +u0 = @SVector rand(2) +p = @SVector rand(4) + +function lotka(u, p, svec=true) + du1 = p[1]*u[1] - p[2]*u[1]*u[2] + du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + if svec + @SVector [du1, du2] + else + @SMatrix [du1 du2 du1; du2 du1 du1] + end +end + +#SVector constructor adjoint +function loss(p) + u = lotka(u0, p) + sum(1 .- u) +end + +grad = Zygote.gradient(loss, p) +@test typeof(grad[1]) <: SArray +grad2 = ForwardDiff.gradient(loss, p) +@test grad[1] ≈ grad2 rtol=1e-12 + +#SMatrix constructor adjoint +function loss_mat(p) + u = lotka(u0, p, false) + sum(1 .- u) +end + +grad = Zygote.gradient(loss_mat, p) +@test typeof(grad[1]) <: SArray +grad2 = ForwardDiff.gradient(loss_mat, p) +@test grad[1] ≈ grad2 rtol=1e-12 -u0 = @SVector [1.0f0, 1.0f0] -p = @SMatrix [1.5f0 -1.0f0; 3.0f0 -1.0f0] -tspan = [0.0f0, 5.0f0] -datasize = 20 +##Adjoints of StaticArrays ODE + +u0 = @SVector [1.0, 1.0] +p = @SVector [1.5,1.0,3.0,1.0] +tspan = (0.0, 5.0) +datasize = 15 tsteps = range(tspan[1], tspan[2], length = datasize) -function f(u, p, t) - p * u +function lotka(u, p, t) + du1 = p[1]*u[1] - p[2]*u[1]*u[2] + du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + @SVector [du1, du2] end -prob = ODEProblem(f, u0, tspan, p) -sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-12, reltol = 1e-12) +prob = ODEProblem(lotka, u0, tspan, p) +sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) ## Discrete Case -dg_disc(u, p, t, i; outtype = nothing) = u .- 1 +dg_disc(u, p, t, i; outtype = nothing) = u du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) -## with ForwardDiff -function G_p(p) - tmp_prob = remake(prob, p = p) - u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12)) +@test !iszero(du0) +@test !iszero(dp) +# +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP()), + tsteps, dg_disc) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP())) +res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) + +@test adj_sol[end] ≈ du0 rtol=1e-12 +@test res ≈ dp rtol=1e-12 + +###Comparing with gradients of lotka volterra with normal arrays +u2 = [1.0, 1.0] +p2 = [1.5,1.0,3.0,1.0] + +function f(u, p, t) + du1 = p[1]*u[1] - p[2]*u[1]*u[2] + du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + [du1, du2] +end + +prob2 = ODEProblem(f, u2, tspan, p2) +sol2 = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) + +function dg_disc(du, u, p, t, i) + du .= u +end + +du1, dp1 = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) + +@test du0 ≈ du1 rtol=1e-12 +@test dp ≈ dp1 rtol=1e-12 + +## with ForwardDiff and Zygote +function G_p(p) + tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) + sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec=ZygoteVJP()), saveat = tsteps) + u = Array(sol) return sum(((1 .- u) .^ 2) ./ 2) end function G_u(u0) - tmp_prob = remake(prob, u0 = u0) - u = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough(), abstol = 1e-12, reltol = 1e-12)) + tmp_prob = remake(prob, u0 = u0, p = prob.p) + sol = solve(tmp_prob, Tsit5(), saveat = tsteps, + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec=ZygoteVJP()), abstol = 1e-14, reltol = 1e-14) + u = Array(sol) + return sum(((1 .- u) .^ 2) ./ 2) end G_p(p) G_u(u0) -n_dp = ForwardDiff.gradient(G_p, p) -n_du0 = ForwardDiff.gradient(G_u, u0) +f_dp = ForwardDiff.gradient(G_p, p) +f_du0 = ForwardDiff.gradient(G_u, u0) -@test n_du0≈du0 rtol=1e-3 -@test_broken n_dp≈dp' rtol=1e-3 -@test sum(n_dp - dp') < 8.0 +z_dp = Zygote.gradient(G_p, p) +z_du0 = Zygote.gradient(G_u, u0) + +@test z_du0[1] ≈ f_du0 rtol=1e-12 +@test z_dp[1] ≈ f_dp rtol=1e-12 ## Continuous Case @@ -57,12 +139,25 @@ function dg(u, p, t) end du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) -##numerical +adj_prob = ODEAdjointProblem(sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP()), + nothing, nothing, nothing, dg, nothing, g) +adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +integrand = AdjointSensitivityIntegrand(sol, adj_sol, + QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = SciMLSensitivity.ZygoteVJP())) +res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) + +@test adj_sol[end] ≈ du0 rtol=1e-12 +@test res ≈ dp rtol=1e-12 + +##ForwardDiff function G_p(p) tmp_prob = remake(prob, p = p) @@ -80,130 +175,21 @@ function G_u(u0) res end -n_du0 = ForwardDiff.gradient(G_u, u0) -n_dp = ForwardDiff.gradient(G_p, p) +f_du0 = ForwardDiff.gradient(G_u, u0) +f_dp = ForwardDiff.gradient(G_p, p) + -@test_broken n_du0≈du0 rtol=1e-3 -@test_broken n_dp≈dp' rtol=1e-3 +@test !iszero(f_du0) +@test !iszero(f_dp) -@test sum(n_du0 - du0) < 1.0 -@test sum(n_dp - dp) < 5.0 ## concrete solve du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, - abstol = 1e-6, reltol = 1e-6, + abstol = 1e-10, reltol = 1e-10, saveat = tsteps, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), + sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14,autojacvec = ZygoteVJP()))), u0, p) @test !iszero(du0) -@test !iszero(dp) - -##Neural ODE adjoint with SimpleChains -u0 = @SArray Float32[2.0, 0.0] -datasize = 30 -tspan = (0.0f0, 1.5f0) -tsteps = range(tspan[1], tspan[2], length = datasize) - -function trueODE(u, p, t) - true_A = @SMatrix Float32[-0.1 2.0; -2.0 -0.1] - ((u .^ 3)'true_A)' -end - -prob = ODEProblem(trueODE, u0, tspan) -sol_n = solve(prob, Tsit5(), saveat = tsteps) -data = Array(solve(prob, Tsit5(), saveat = tsteps)) - -sc = SimpleChain(static(2), - Activation(x -> x .^ 3), - TurboDense{true}(tanh, static(50)), - TurboDense{true}(identity, static(2))) - -p_nn = SimpleChains.init_params(sc) - -df(u, p, t) = sc(u, p) - -prob_nn = ODEProblem(df, u0, tspan, p_nn) -sol = solve(prob_nn, Tsit5(); saveat = tsteps) - -dg_disc(u, p, t, i; outtype = nothing) = u .- data[:, i] - -du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) - -@test !iszero(du0) -@test !iszero(dp) - -## numerical - -function G_p(p) - tmp_prob = remake(prob_nn, u0 = prob_nn.u0, p = p) - A = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough())) - - return sum(((data .- A) .^ 2) ./ 2) -end -function G_u(u0) - tmp_prob = remake(prob_nn, u0 = u0, p = p_nn) - A = Array(solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = SensitivityADPassThrough())) - return sum(((data .- A) .^ 2) ./ 2) -end -G_p(p_nn) -G_u(u0) -n_dp = ForwardDiff.gradient(G_p, p_nn) -n_du0 = ForwardDiff.gradient(G_u, u0) - -@test n_du0≈du0 rtol=1e-3 -@test n_dp≈dp' rtol=1e-3 - -## Continuous case - -G(u, p, t) = sum(((data .- u) .^ 2) ./ 2) - -function dg(u, p, t) - return u .- Array(sol_n(t)) -end - -du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = G, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP())) - -@test !iszero(du0) -@test !iszero(dp) -##numerical - -function G_p(p) - tmp_prob = remake(prob_nn, p = p) - sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5, - atol = 1e-12, - rtol = 1e-12) # sol_n(t):numerical solution/data(above) - res -end - -function G_u(u0) - tmp_prob = remake(prob_nn, u0 = u0) - sol = solve(tmp_prob, Tsit5(), abstol = 1e-12, reltol = 1e-12) - res, err = quadgk((t) -> (sum(((sol_n(t) .- sol(t)) .^ 2) ./ 2)), 0.0, 1.5, - atol = 1e-12, - rtol = 1e-12) # sol_n(t):numerical solution/data(above) - res -end - -n_du0 = ForwardDiff.gradient(G_u, u0) -n_dp = ForwardDiff.gradient(G_p, p_nn) - -@test n_du0≈du0 rtol=1e-3 -@test n_dp≈dp' rtol=1e-3 - -#concrete_solve - -du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob_nn, Tsit5(), u0, p, - abstol = 1e-12, reltol = 1e-12, - saveat = tsteps, - sensealg = QuadratureAdjoint(autojacvec = ZygoteVJP()))), - u0, p_nn) - -@test !iszero(du0) -@test !iszero(dp) +@test !iszero(dp) \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 0747a802f..ee77951da 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,7 +45,7 @@ end @time @safetestset "Adjoint Sensitivity" begin include("adjoint.jl") end @time @safetestset "Continuous adjoint params" begin include("adjoint_param.jl") end @time @safetestset "Continuous and discrete costs" begin include("mixed_costs.jl") end - @time @safetestset "Fully Out of Place adjoint sensitivity with StaticArrays and SimpleChains" begin include("adjoint_oop.jl") end + @time @safetestset "Fully Out of Place adjoint sensitivity" begin include("adjoint_oop.jl") end end if GROUP == "All" || GROUP == "Core4" From eb850e9236b0b8538fdfa73d2d9569b76c980bb8 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Fri, 12 Aug 2022 20:15:41 +0530 Subject: [PATCH 20/25] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 49efce3e1..194a7d38c 100644 --- a/Project.toml +++ b/Project.toml @@ -94,4 +94,4 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "SimpleChains", "StaticArrays"] +test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "StaticArrays"] From 5971d27441866e40c5f4b7cace97302f17fc8721 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Fri, 12 Aug 2022 20:36:19 +0530 Subject: [PATCH 21/25] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 194a7d38c..84376470d 100644 --- a/Project.toml +++ b/Project.toml @@ -94,4 +94,4 @@ SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "StaticArrays"] +test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays"] From d9f7ed001cc5743aab9878af1db996f6a3630106 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Fri, 12 Aug 2022 20:59:45 +0530 Subject: [PATCH 22/25] formatter --- src/SciMLSensitivity.jl | 3 +- src/staticarrays.jl | 10 +++--- test/adjoint_oop.jl | 67 +++++++++++++++++++++++------------------ 3 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index 9c7a489ad..cf2769680 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -25,7 +25,8 @@ using EllipsisNotation using Markdown using Reexport -import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, project_type, _eltype_projectto, rrule +import ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ProjectTo, + project_type, _eltype_projectto, rrule abstract type SensitivityFunction end abstract type TransformedFunction end diff --git a/src/staticarrays.jl b/src/staticarrays.jl index 1cb924882..11cc1aaad 100644 --- a/src/staticarrays.jl +++ b/src/staticarrays.jl @@ -6,18 +6,18 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::StaticArrays.SArray) end ### Project SArray to SArray -function ProjectTo(x::StaticArrays.SArray{S,T}) where {S, T} - return ProjectTo{StaticArrays.SArray}(; element=_eltype_projectto(T), axes=S) +function ProjectTo(x::StaticArrays.SArray{S, T}) where {S, T} + return ProjectTo{StaticArrays.SArray}(; element = _eltype_projectto(T), axes = S) end -function (project::ProjectTo{StaticArrays.SArray})(dx::AbstractArray{S,M}) where {S,M} +function (project::ProjectTo{StaticArrays.SArray})(dx::AbstractArray{S, M}) where {S, M} return StaticArrays.SArray{project.axes}(dx) end ### Adjoint for SArray constructor -function rrule(::Type{T}, x::Tuple) where {T<:StaticArrays.SArray} +function rrule(::Type{T}, x::Tuple) where {T <: StaticArrays.SArray} project_x = ProjectTo(x) Array_pullback(ȳ) = (NoTangent(), project_x(ȳ)) return T(x), Array_pullback -end \ No newline at end of file +end diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 630c6f8ac..99f2f9c3c 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -6,9 +6,9 @@ using Test u0 = @SVector rand(2) p = @SVector rand(4) -function lotka(u, p, svec=true) - du1 = p[1]*u[1] - p[2]*u[1]*u[2] - du2 = -p[3]*u[2] + p[4]*u[1]*u[2] +function lotka(u, p, svec = true) + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] if svec @SVector [du1, du2] else @@ -21,12 +21,12 @@ function loss(p) u = lotka(u0, p) sum(1 .- u) end - + grad = Zygote.gradient(loss, p) @test typeof(grad[1]) <: SArray grad2 = ForwardDiff.gradient(loss, p) -@test grad[1] ≈ grad2 rtol=1e-12 - +@test grad[1]≈grad2 rtol=1e-12 + #SMatrix constructor adjoint function loss_mat(p) u = lotka(u0, p, false) @@ -36,19 +36,19 @@ end grad = Zygote.gradient(loss_mat, p) @test typeof(grad[1]) <: SArray grad2 = ForwardDiff.gradient(loss_mat, p) -@test grad[1] ≈ grad2 rtol=1e-12 +@test grad[1]≈grad2 rtol=1e-12 ##Adjoints of StaticArrays ODE u0 = @SVector [1.0, 1.0] -p = @SVector [1.5,1.0,3.0,1.0] +p = @SVector [1.5, 1.0, 3.0, 1.0] tspan = (0.0, 5.0) datasize = 15 tsteps = range(tspan[1], tspan[2], length = datasize) function lotka(u, p, t) - du1 = p[1]*u[1] - p[2]*u[1]*u[2] - du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] @SVector [du1, du2] end @@ -59,7 +59,8 @@ sol = solve(prob, Tsit5(), saveat = tsteps, abstol = 1e-14, reltol = 1e-14) dg_disc(u, p, t, i; outtype = nothing) = u du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) @@ -74,16 +75,16 @@ integrand = AdjointSensitivityIntegrand(sol, adj_sol, autojacvec = SciMLSensitivity.ZygoteVJP())) res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) -@test adj_sol[end] ≈ du0 rtol=1e-12 -@test res ≈ dp rtol=1e-12 +@test adj_sol[end]≈du0 rtol=1e-12 +@test res≈dp rtol=1e-12 ###Comparing with gradients of lotka volterra with normal arrays u2 = [1.0, 1.0] -p2 = [1.5,1.0,3.0,1.0] +p2 = [1.5, 1.0, 3.0, 1.0] function f(u, p, t) - du1 = p[1]*u[1] - p[2]*u[1]*u[2] - du2 = -p[3]*u[2] + p[4]*u[1]*u[2] + du1 = p[1] * u[1] - p[2] * u[1] * u[2] + du2 = -p[3] * u[2] + p[4] * u[1] * u[2] [du1, du2] end @@ -95,17 +96,20 @@ function dg_disc(du, u, p, t, i) end du1, dp1 = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_disc, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ZygoteVJP())) -@test du0 ≈ du1 rtol=1e-12 -@test dp ≈ dp1 rtol=1e-12 +@test du0≈du1 rtol=1e-12 +@test dp≈dp1 rtol=1e-12 ## with ForwardDiff and Zygote function G_p(p) tmp_prob = remake(prob, u0 = convert.(eltype(p), prob.u0), p = p) sol = solve(tmp_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec=ZygoteVJP()), saveat = tsteps) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP()), saveat = tsteps) u = Array(sol) return sum(((1 .- u) .^ 2) ./ 2) end @@ -113,7 +117,9 @@ end function G_u(u0) tmp_prob = remake(prob, u0 = u0, p = prob.p) sol = solve(tmp_prob, Tsit5(), saveat = tsteps, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec=ZygoteVJP()), abstol = 1e-14, reltol = 1e-14) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP()), abstol = 1e-14, + reltol = 1e-14) u = Array(sol) return sum(((1 .- u) .^ 2) ./ 2) @@ -127,8 +133,8 @@ f_du0 = ForwardDiff.gradient(G_u, u0) z_dp = Zygote.gradient(G_p, p) z_du0 = Zygote.gradient(G_u, u0) -@test z_du0[1] ≈ f_du0 rtol=1e-12 -@test z_dp[1] ≈ f_dp rtol=1e-12 +@test z_du0[1]≈f_du0 rtol=1e-12 +@test z_dp[1]≈f_dp rtol=1e-12 ## Continuous Case @@ -139,7 +145,8 @@ function dg(u, p, t) end du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14, autojacvec = ZygoteVJP())) + sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, + autojacvec = ZygoteVJP())) @test !iszero(du0) @test !iszero(dp) @@ -154,8 +161,8 @@ integrand = AdjointSensitivityIntegrand(sol, adj_sol, autojacvec = SciMLSensitivity.ZygoteVJP())) res, err = quadgk(integrand, 0.0, 5.0, atol = 1e-14, rtol = 1e-14) -@test adj_sol[end] ≈ du0 rtol=1e-12 -@test res ≈ dp rtol=1e-12 +@test adj_sol[end]≈du0 rtol=1e-12 +@test res≈dp rtol=1e-12 ##ForwardDiff @@ -178,18 +185,18 @@ end f_du0 = ForwardDiff.gradient(G_u, u0) f_dp = ForwardDiff.gradient(G_p, p) - @test !iszero(f_du0) @test !iszero(f_dp) - ## concrete solve du0, dp = Zygote.gradient((u0, p) -> sum(concrete_solve(prob, Tsit5(), u0, p, abstol = 1e-10, reltol = 1e-10, saveat = tsteps, - sensealg = QuadratureAdjoint(abstol=1e-14, reltol=1e-14,autojacvec = ZygoteVJP()))), + sensealg = QuadratureAdjoint(abstol = 1e-14, + reltol = 1e-14, + autojacvec = ZygoteVJP()))), u0, p) @test !iszero(du0) -@test !iszero(dp) \ No newline at end of file +@test !iszero(dp) From e7c89050cfe98ce8e327454f69a9b8eb4b195307 Mon Sep 17 00:00:00 2001 From: Abhishek Bhatt Date: Fri, 12 Aug 2022 23:08:48 +0530 Subject: [PATCH 23/25] ODEAdjointProblem fix --- test/adjoint_oop.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/adjoint_oop.jl b/test/adjoint_oop.jl index 99f2f9c3c..c8e5883e6 100644 --- a/test/adjoint_oop.jl +++ b/test/adjoint_oop.jl @@ -68,7 +68,7 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); t = tsteps, dgdu_discrete = dg_dis adj_prob = ODEAdjointProblem(sol, QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.ZygoteVJP()), - tsteps, dg_disc) + Tsit5(), tsteps, dg_disc) adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) integrand = AdjointSensitivityIntegrand(sol, adj_sol, QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, @@ -154,7 +154,7 @@ du0, dp = adjoint_sensitivities(sol, Tsit5(); dgdu_continuous = dg, g = g, adj_prob = ODEAdjointProblem(sol, QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.ZygoteVJP()), - nothing, nothing, nothing, dg, nothing, g) + Tsit5(), nothing, nothing, nothing, dg, nothing, g) adj_sol = solve(adj_prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) integrand = AdjointSensitivityIntegrand(sol, adj_sol, QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14, From 02f1ae29b3c3a4ecd8db24ea0ef888028e6033c5 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 12 Aug 2022 23:17:19 -0400 Subject: [PATCH 24/25] format --- test/stiff_adjoints.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/stiff_adjoints.jl b/test/stiff_adjoints.jl index 174bd903f..5bc05665d 100644 --- a/test/stiff_adjoints.jl +++ b/test/stiff_adjoints.jl @@ -175,7 +175,7 @@ if VERSION >= v"1.7-" ROCK4(), RKC(), # SERK2v2(), not defined? - ESERK5()]; + ESERK5()] p = rand(3) From 7340ca4f76e40fd0712aa83fc3914a78f0c05d59 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 13 Aug 2022 00:30:10 -0400 Subject: [PATCH 25/25] no Enzyme v1.6 --- test/stiff_adjoints.jl | 77 ++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 37 deletions(-) diff --git a/test/stiff_adjoints.jl b/test/stiff_adjoints.jl index 5bc05665d..a866cdc53 100644 --- a/test/stiff_adjoints.jl +++ b/test/stiff_adjoints.jl @@ -214,45 +214,48 @@ if VERSION >= v"1.7-" dp1 = @test_broken Zygote.gradient(p -> loss(p, ReverseDiffAdjoint()), p)[1] @test_broken dp≈dp1 rtol=1e-2 end -end -using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, Test + # using SciMLSensitivity, OrdinaryDiffEq, ForwardDiff, Zygote, Test -function rober(du, u, p, t) - y₁, y₂, y₃ = u - k₁, k₂, k₃ = p[1], p[2], p[3] - du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ - du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ - du[3] = k₂ * y₂^2 + sum(p) - nothing -end + function rober(du, u, p, t) + y₁, y₂, y₃ = u + k₁, k₂, k₃ = p[1], p[2], p[3] + du[1] = -k₁ * y₁ + k₃ * y₂ * y₃ + du[2] = k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ + du[3] = k₂ * y₂^2 + sum(p) + nothing + end -function sum_of_solution_fwd(x) - _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) - sum(solve(_prob, Rodas5(), saveat = 1, reltol = 1e-12, abstol = 1e-12)) -end + function sum_of_solution_fwd(x) + _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) + sum(solve(_prob, Rodas5(), saveat = 1, reltol = 1e-12, abstol = 1e-12)) + end -function sum_of_solution_CASA(x; vjp = EnzymeVJP()) - sensealg = QuadratureAdjoint(autodiff = false, autojacvec = vjp) - _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) - sum(solve(_prob, Rodas5(), reltol = 1e-8, abstol = 1e-8, saveat = 1, - sensealg = sensealg)) -end + function sum_of_solution_CASA(x; vjp = EnzymeVJP()) + sensealg = QuadratureAdjoint(autodiff = false, autojacvec = vjp) + _prob = ODEProblem(rober, x[1:3], (0.0, 1e4), x[4:end]) + sum(solve(_prob, Rodas5(), reltol = 1e-8, abstol = 1e-8, saveat = 1, + sensealg = sensealg)) + end -u0 = [1.0, 0.0, 0.0] -p = ones(8) # change me, the number of parameters - -grad1 = ForwardDiff.gradient(sum_of_solution_fwd, [u0; p]) -grad2 = Zygote.gradient(sum_of_solution_CASA, [u0; p])[1] -grad3 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP()), [u0; p])[1] -grad4 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP(true)), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = true), [u0; p])[1] -grad6 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = false), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ZygoteVJP()), [u0; p])[1] -@test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = TrackerVJP()), [u0; p])[1] - -@test grad1 ≈ grad2 -@test grad1 ≈ grad3 -@test grad1 ≈ grad4 -#@test grad1 ≈ grad5 -@test grad1 ≈ grad6 + u0 = [1.0, 0.0, 0.0] + p = ones(8) # change me, the number of parameters + + grad1 = ForwardDiff.gradient(sum_of_solution_fwd, [u0; p]) + grad2 = Zygote.gradient(sum_of_solution_CASA, [u0; p])[1] + grad3 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP()), [u0; p])[1] + grad4 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ReverseDiffVJP(true)), + [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = true), [u0; p])[1] + grad6 = Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = false), [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = ZygoteVJP()), + [u0; p])[1] + @test_throws Any Zygote.gradient(x -> sum_of_solution_CASA(x, vjp = TrackerVJP()), + [u0; p])[1] + + @test grad1 ≈ grad2 + @test grad1 ≈ grad3 + @test grad1 ≈ grad4 + #@test grad1 ≈ grad5 + @test grad1 ≈ grad6 +end