diff --git a/Project.toml b/Project.toml index da82c0c874..5dd809bd8e 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" diff --git a/src/DiffEqFlux.jl b/src/DiffEqFlux.jl index fbb7583c4b..ecd8e0f706 100644 --- a/src/DiffEqFlux.jl +++ b/src/DiffEqFlux.jl @@ -18,6 +18,7 @@ using Requires using Cassette @reexport using Flux @reexport using OptimizationOptimJL +using Functors import ChainRulesCore diff --git a/src/hnn.jl b/src/hnn.jl index faa8f04c89..4858ecb5ad 100644 --- a/src/hnn.jl +++ b/src/hnn.jl @@ -90,7 +90,7 @@ Arguments: documentation for more details. """ struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer - hnn::HamiltonianNN{M,RE,P} + model::HamiltonianNN{M,RE,P} p::P tspan::T args::A @@ -112,7 +112,7 @@ end function (nhde::NeuralHamiltonianDE)(x, p = nhde.p) function neural_hamiltonian!(du, u, p, t) - du .= reshape(nhde.hnn(u, p), size(du)) + du .= reshape(nhde.model(u, p), size(du)) end prob = ODEProblem(neural_hamiltonian!, x, nhde.tspan, p) # NOTE: Nesting Zygote is an issue. 
So we can't use ZygoteVJP diff --git a/src/neural_de.jl b/src/neural_de.jl index a847d5c25e..150ed1da26 100644 --- a/src/neural_de.jl +++ b/src/neural_de.jl @@ -1,6 +1,6 @@ -abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end +abstract type NeuralDELayer <: Lux.AbstractExplicitContainerLayer{(:model,)} end +abstract type NeuralSDELayer <: Lux.AbstractExplicitContainerLayer{(:drift,:diffusion,)} end basic_tgrad(u,p,t) = zero(u) -Flux.trainable(m::NeuralDELayer) = (m.p,) """ Constructs a continuous-time recurrant neural network, also known as a neural @@ -43,34 +43,33 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer tspan::T args::A kwargs::K +end - function NeuralODE(model,tspan,args...;p = nothing,kwargs...) - _p,re = Flux.destructure(model) - if p === nothing - p = _p - end - new{typeof(model),typeof(p),typeof(re), +function NeuralODE(model,tspan,args...;p = nothing,kwargs...) + _p,re = Flux.destructure(model) + if p === nothing + p = _p + end + NeuralODE{typeof(model),typeof(p),typeof(re), typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,tspan,args,kwargs) - end + model,p,re,tspan,args,kwargs) +end - function NeuralODE(model::FastChain,tspan,args...;p=initial_params(model),kwargs...) - re = nothing - new{typeof(model),typeof(p),typeof(re), - typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,tspan,args,kwargs) - end +function NeuralODE(model::FastChain,tspan,args...;p=initial_params(model),kwargs...) + re = nothing + NeuralODE{typeof(model),typeof(p),typeof(re), + typeof(tspan),typeof(args),typeof(kwargs)}( + model,p,re,tspan,args,kwargs) +end - function NeuralODE(model::Lux.AbstractExplicitLayer,tspan,args...;p=nothing,kwargs...) - re = nothing - new{typeof(model),typeof(p),typeof(re), - typeof(tspan),typeof(args),typeof(kwargs)}( - model,p,re,tspan,args,kwargs) - end +function NeuralODE(model::Lux.AbstractExplicitLayer,tspan,args...;p=nothing,kwargs...) 
+ re = nothing + NeuralODE{typeof(model),typeof(p),typeof(re), + typeof(tspan),typeof(args),typeof(kwargs)}( + model,p,re,tspan,args,kwargs) end -Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model) -Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model) +@functor NeuralODE (p,) function (n::NeuralODE)(x,p=n.p) dudt_(u,p,t) = n.re(p)(u) @@ -104,16 +103,16 @@ end Constructs a neural stochastic differential equation (neural SDE) with diagonal noise. ```julia -NeuralDSDE(model1,model2,tspan,alg=nothing,args...; +NeuralDSDE(drift,diffusion,tspan,alg=nothing,args...; sensealg=TrackerAdjoint(),kwargs...) -NeuralDSDE(model1::FastChain,model2::FastChain,tspan,alg=nothing,args...; +NeuralDSDE(drift::FastChain,diffusion::FastChain,tspan,alg=nothing,args...; sensealg=TrackerAdjoint(),kwargs...) ``` Arguments: -- `model1`: A Chain or FastChain neural network that defines the drift function. -- `model2`: A Chain or FastChain neural network that defines the diffusion function. +- `drift`: A Chain or FastChain neural network that defines the drift function. +- `diffusion`: A Chain or FastChain neural network that defines the diffusion function. Should output a vector of the same size as the input. - `tspan`: The timespan to be solved on. - `alg`: The algorithm used to solve the ODE. Defaults to `nothing`, i.e. the @@ -125,48 +124,51 @@ Arguments: documentation for more details. """ -struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer +struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralSDELayer p::P len::Int - model1::M + drift::M re1::RE - model2::M2 + diffusion::M2 re2::RE2 tspan::T args::A kwargs::K - function NeuralDSDE(model1,model2,tspan,args...;p = nothing, kwargs...) 
- p1,re1 = Flux.destructure(model1) - p2,re2 = Flux.destructure(model2) - if p === nothing - p = [p1;p2] - end - new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}(p, - length(p1),model1,re1,model2,re2,tspan,args,kwargs) - end +end - function NeuralDSDE(model1::FastChain,model2::FastChain,tspan,args...; - p1 = initial_params(model1), - p = [p1;initial_params(model2)], kwargs...) - re1 = nothing - re2 = nothing - new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}(p, - length(p1),model1,re1,model2,re2,tspan,args,kwargs) - end +function NeuralDSDE(drift,diffusion,tspan,args...;p = nothing, kwargs...) + p1,re1 = Flux.destructure(drift) + p2,re2 = Flux.destructure(diffusion) + if p === nothing + p = [p1;p2] + end + NeuralDSDE{typeof(drift),typeof(p),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}(p, + length(p1),drift,re1,diffusion,re2,tspan,args,kwargs) +end - function NeuralDSDE(model1::Lux.Chain,model2::Lux.Chain,tspan,args...; - p1 =nothing, - p = nothing, kwargs...) - re1 = nothing - re2 = nothing - new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}(p, - Int(1),model1,re1,model2,re2,tspan,args,kwargs) - end +function NeuralDSDE(drift::FastChain,diffusion::FastChain,tspan,args...; + p1 = initial_params(drift), + p = [p1;initial_params(diffusion)], kwargs...) + re1 = nothing + re2 = nothing + NeuralDSDE{typeof(drift),typeof(p),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}(p, + length(p1),drift,re1,diffusion,re2,tspan,args,kwargs) +end + +function NeuralDSDE(drift::Lux.AbstractExplicitLayer,diffusion::Lux.AbstractExplicitLayer,tspan,args...; + p1 =nothing, + p = nothing, kwargs...)
+ re1 = nothing + re2 = nothing + NeuralDSDE{typeof(drift),typeof(p),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}(p, + Int(1),drift,re1,diffusion,re2,tspan,args,kwargs) end +@functor NeuralDSDE (p,) + function (n::NeuralDSDE)(x,p=n.p) dudt_(u,p,t) = n.re1(p[1:n.len])(u) g(u,p,t) = n.re2(p[(n.len+1):end])(u) @@ -176,56 +178,44 @@ function (n::NeuralDSDE)(x,p=n.p) end function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain} - dudt_(u,p,t) = n.model1(u,p[1:n.len]) - g(u,p,t) = n.model2(u,p[(n.len+1):end]) + dudt_(u,p,t) = n.drift(u,p[1:n.len]) + g(u,p,t) = n.diffusion(u,p[(n.len+1):end]) ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) prob = SDEProblem{false}(ff,g,x,n.tspan,p) solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) end -function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE) - p1 = Lux.initialparameters(rng, n.model1) - p2 = Lux.initialparameters(rng, n.model2) - return Lux.ComponentArray((p1 = p1, p2 = p2)) -end - -function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE) - st1 = Lux.initialstates(rng, n.model1) - st2 = Lux.initialstates(rng, n.model2) - return (state1 = st1, state2 = st2) -end - function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer} - st1 = st.state1 - st2 = st.state2 + st1 = st.drift + st2 = st.diffusion function dudt_(u,p,t;st=st1) - u_, st = n.model1(u,p.p1,st) + u_, st = n.drift(u,p.drift,st) return u_ end function g(u,p,t;st=st2) - u_, st = n.model2(u,p.p2,st) + u_, st = n.diffusion(u,p.diffusion,st) return u_ end ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) prob = SDEProblem{false}(ff,g,x,n.tspan,p) - return solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2) + return solve(prob,n.args...;sensealg=BacksolveAdjoint(),n.kwargs...), (drift = st1, diffusion = st2) end """ Constructs a neural stochastic differential equation (neural SDE). 
```julia -NeuralSDE(model1,model2,tspan,nbrown,alg=nothing,args...; +NeuralSDE(drift,diffusion,tspan,nbrown,alg=nothing,args...; sensealg=TrackerAdjoint(),kwargs...) -NeuralSDE(model1::FastChain,model2::FastChain,tspan,nbrown,alg=nothing,args...; +NeuralSDE(drift::FastChain,diffusion::FastChain,tspan,nbrown,alg=nothing,args...; sensealg=TrackerAdjoint(),kwargs...) ``` Arguments: -- `model1`: A Chain or FastChain neural network that defines the drift function. -- `model2`: A Chain or FastChain neural network that defines the diffusion function. +- `drift`: A Chain or FastChain neural network that defines the drift function. +- `diffusion`: A Chain or FastChain neural network that defines the diffusion function. Should output a matrix that is nbrown x size(x,1). - `tspan`: The timespan to be solved on. - `nbrown`: The number of Brownian processes @@ -238,48 +228,51 @@ Arguments: documentation for more details. """ -struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer +struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralSDELayer p::P len::Int - model1::M + drift::M re1::RE - model2::M2 + diffusion::M2 re2::RE2 tspan::T nbrown::Int args::A kwargs::K - function NeuralSDE(model1,model2,tspan,nbrown,args...;p=nothing,kwargs...) - p1,re1 = Flux.destructure(model1) - p2,re2 = Flux.destructure(model2) - if p === nothing - p = [p1;p2] - end - new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}( - p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs) - end +end - function NeuralSDE(model1::FastChain,model2::FastChain,tspan,nbrown,args...; - p1 = initial_params(model1), - p = [p1;initial_params(model2)], kwargs...) - re1 = nothing - re2 = nothing - new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}( - p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs) - end +function NeuralSDE(drift,diffusion,tspan,nbrown,args...;p=nothing,kwargs...) 
+ p1,re1 = Flux.destructure(drift) + p2,re2 = Flux.destructure(diffusion) + if p === nothing + p = [p1;p2] + end + NeuralSDE{typeof(p),typeof(drift),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}( + p,length(p1),drift,re1,diffusion,re2,tspan,nbrown,args,kwargs) +end - function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...; - p1 = nothing, p = nothing, kwargs...) - re1 = nothing - re2 = nothing - new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2), - typeof(tspan),typeof(args),typeof(kwargs)}( - p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs) - end +function NeuralSDE(drift::FastChain,diffusion::FastChain,tspan,nbrown,args...; + p1 = initial_params(drift), + p = [p1;initial_params(diffusion)], kwargs...) + re1 = nothing + re2 = nothing + NeuralSDE{typeof(p),typeof(drift),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}( + p,length(p1),drift,re1,diffusion,re2,tspan,nbrown,args,kwargs) end +function NeuralSDE(drift::Lux.AbstractExplicitLayer, diffusion::Lux.AbstractExplicitLayer,tspan,nbrown,args...; + p1 = nothing, p = nothing, kwargs...) 
+ re1 = nothing + re2 = nothing + NeuralSDE{typeof(p),typeof(drift),typeof(re1),typeof(diffusion),typeof(re2), + typeof(tspan),typeof(args),typeof(kwargs)}( + p,Int(1),drift,re1,diffusion,re2,tspan,nbrown,args,kwargs) +end + +@functor NeuralSDE (p,) + function (n::NeuralSDE)(x,p=n.p) dudt_(u,p,t) = n.re1(p[1:n.len])(u) g(u,p,t) = n.re2(p[(n.len+1):end])(u) @@ -289,39 +282,28 @@ function (n::NeuralSDE)(x,p=n.p) end function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain} - dudt_(u,p,t) = n.model1(u,p[1:n.len]) - g(u,p,t) = n.model2(u,p[(n.len+1):end]) + dudt_(u,p,t) = n.drift(u,p[1:n.len]) + g(u,p,t) = n.diffusion(u,p[(n.len+1):end]) ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown)) solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) end -function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE) - p1 = Lux.initialparameters(rng, n.model1) - p2 = Lux.initialparameters(rng, n.model2) - return Lux.ComponentArray((p1 = p1, p2 = p2)) -end -function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE) - st1 = Lux.initialstates(rng, n.model1) - st2 = Lux.initialstates(rng, n.model2) - return (state1 = st1, state2 = st2) -end - function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer} - st1 = st.state1 - st2 = st.state2 + st1 = st.drift + st2 = st.diffusion function dudt_(u,p,t;st=st1) - u_, st = n.model1(u,p.p1,st) + u_, st = n.drift(u,p.drift,st) return u_ end function g(u,p,t;st=st2) - u_, st = n.model2(u,p.p2,st) + u_, st = n.diffusion(u,p.diffusion,st) return u_ end ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad) prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown)) - solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2) + return solve(prob,n.args...;sensealg=BacksolveAdjoint(),n.kwargs...), (drift = st1, diffusion = st2) end """ @@ -353,6 +335,18 @@ 
Arguments: documentation for more details. """ +const Unsupported_NeuralCDDE_pairing_message = """ + NeuralCDDE can only be instantiated with a Flux chain + """ + +struct Unsupported_pairing <:Exception + msg::Any +end + +function Base.showerror(io::IO, e::Unsupported_pairing) + print(io, e.msg) +end + struct NeuralCDDE{P,M,RE,H,L,T,A,K} <: NeuralDELayer p::P model::M @@ -362,25 +356,35 @@ struct NeuralCDDE{P,M,RE,H,L,T,A,K} <: NeuralDELayer tspan::T args::A kwargs::K +end - function NeuralCDDE(model,tspan,hist,lags,args...;p=nothing,kwargs...) - _p,re = Flux.destructure(model) - if p === nothing - p = _p - end - new{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), - typeof(tspan),typeof(args),typeof(kwargs)}(p,model, - re,hist,lags,tspan,args,kwargs) - end +function NeuralCDDE(model,tspan,hist,lags,args...;p=nothing,kwargs...) + _p,re = Flux.destructure(model) + if p === nothing + p = _p + end + NeuralCDDE{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), + typeof(tspan),typeof(args),typeof(kwargs)}(p,model, + re,hist,lags,tspan,args,kwargs) +end - function NeuralCDDE(model::FastChain,tspan,hist,lags,args...;p = initial_params(model),kwargs...) - re = nothing - new{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), - typeof(tspan),typeof(args),typeof(kwargs)}(p,model, - re,hist,lags,tspan,args,kwargs) - end +function NeuralCDDE(model::FastChain,tspan,hist,lags,args...;p = initial_params(model),kwargs...) + re = nothing + NeuralCDDE{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), + typeof(tspan),typeof(args),typeof(kwargs)}(p,model, + re,hist,lags,tspan,args,kwargs) end +function NeuralCDDE(model::Lux.AbstractExplicitLayer,tspan,hist,lags,args...;p = nothing,kwargs...)
+ throw(Unsupported_pairing(Unsupported_NeuralCDDE_pairing_message)) +# re = nothing +# new{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags), +# typeof(tspan),typeof(args),typeof(kwargs)}(p,model, +# re,hist,lags,tspan,args,kwargs) +end + +@functor NeuralCDDE (p,) + function (n::NeuralCDDE)(x,p=n.p) function dudt_(u,h,p,t) _u = vcat(u,(h(p,t-lag) for lag in n.lags)...) @@ -438,22 +442,34 @@ struct NeuralDAE{P,M,M2,D,RE,T,DV,A,K} <: NeuralDELayer differential_vars::DV args::A kwargs::K +end - function NeuralDAE(model,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...) - _p,re = Flux.destructure(model) +function NeuralDAE(model,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...) + _p,re = Flux.destructure(model) - if p === nothing - p = _p - end + if p === nothing + p = _p + end - new{typeof(p),typeof(model),typeof(constraints_model), - typeof(du0),typeof(re),typeof(tspan), - typeof(differential_vars),typeof(args),typeof(kwargs)}( - model,constraints_model,p,du0,re,tspan,differential_vars, - args,kwargs) - end + NeuralDAE{typeof(p),typeof(model),typeof(constraints_model), + typeof(du0),typeof(re),typeof(tspan), + typeof(differential_vars),typeof(args),typeof(kwargs)}( + model,constraints_model,p,du0,re,tspan,differential_vars, + args,kwargs) end +function NeuralDAE(model::Lux.AbstractExplicitLayer,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...) 
+ re = nothing + + NeuralDAE{typeof(p),typeof(model),typeof(constraints_model), + typeof(du0),typeof(re),typeof(tspan), + typeof(differential_vars),typeof(args),typeof(kwargs)}( + model,constraints_model,p,du0,re,tspan,differential_vars, + args,kwargs) +end + +@functor NeuralDAE (p,) + function (n::NeuralDAE)(x,du0=n.du0,p=n.p) function f(du,u,p,t) nn_out = n.re(p)(vcat(u,du)) @@ -474,6 +490,27 @@ function (n::NeuralDAE)(x,du0=n.du0,p=n.p) solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...) end +function (n::NeuralDAE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer} + du0 = n.du0 + function f(du,u,p,t;st=st) + nn_out, st = n.model(vcat(u,du),p,st) + alg_out = n.constraints_model(u,p,t) + iter_nn = 0 + iter_consts = 0 + map(n.differential_vars) do isdiff + if isdiff + iter_nn += 1 + nn_out[iter_nn] + else + iter_consts += 1 + alg_out[iter_consts] + end + end + end + prob = DAEProblem{false}(f,du0,x,n.tspan,p,differential_vars=n.differential_vars) + return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st +end + """ Constructs a physically-constrained continuous-time recurrant neural network, also known as a neural differential-algebraic equation (neural DAE), with a @@ -524,37 +561,38 @@ struct NeuralODEMM{M,M2,P,RE,T,MM,A,K} <: NeuralDELayer mass_matrix::MM args::A kwargs::K +end - function NeuralODEMM(model,constraints_model,tspan,mass_matrix,args...; - p = nothing, kwargs...) - _p,re = Flux.destructure(model) - - if p === nothing - p = _p - end - new{typeof(model),typeof(constraints_model),typeof(p),typeof(re), - typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( - model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) - end +function NeuralODEMM(model,constraints_model,tspan,mass_matrix,args...; + p = nothing, kwargs...) + _p,re = Flux.destructure(model) - function NeuralODEMM(model::FastChain,constraints_model,tspan,mass_matrix,args...; - p = initial_params(model), kwargs...) 
- re = nothing - new{typeof(model),typeof(constraints_model),typeof(p),typeof(re), - typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( - model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) - end + if p === nothing + p = _p + end + NeuralODEMM{typeof(model),typeof(constraints_model),typeof(p),typeof(re), + typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( + model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) +end - function NeuralODEMM(model::Lux.Chain,constraints_model,tspan,mass_matrix,args...; - p=nothing,kwargs...) - re = nothing - new{typeof(model),typeof(constraints_model),typeof(p),typeof(re), - typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( - model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) - end +function NeuralODEMM(model::FastChain,constraints_model,tspan,mass_matrix,args...; + p = initial_params(model), kwargs...) + re = nothing + NeuralODEMM{typeof(model),typeof(constraints_model),typeof(p),typeof(re), + typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( + model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) +end +function NeuralODEMM(model::Lux.AbstractExplicitLayer,constraints_model,tspan,mass_matrix,args...; + p=nothing,kwargs...) + re = nothing + NeuralODEMM{typeof(model),typeof(constraints_model),typeof(p),typeof(re), + typeof(tspan),typeof(mass_matrix),typeof(args),typeof(kwargs)}( + model,constraints_model,p,re,tspan,mass_matrix,args,kwargs) end +@functor NeuralODEMM (p,) + function (n::NeuralODEMM)(x,p=n.p) function f(u,p,t) nn_out = n.re(p)(u) @@ -582,7 +620,7 @@ function (n::NeuralODEMM{M})(x,p=n.p) where {M<:FastChain} end function (n::NeuralODEMM{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer} - function f(u,p,t) + function f(u,p,t;st=st) nn_out,st = n.model(u,p,st) alg_out = n.constraints_model(u,p,t) return vcat(nn_out,alg_out) @@ -611,7 +649,8 @@ References: [1] Dupont, Emilien, Arnaud Doucet, and Yee Whye Teh. "Augmented neural ODEs." 
In Proceedings of the 33rd International Conference on Neural Information Processing Systems, pp. 3140-3150. 2019. """ -struct AugmentedNDELayer{DE<:NeuralDELayer} <: NeuralDELayer +abstract type AugmentedNDEType <: Lux.AbstractExplicitContainerLayer{(:nde,)} end +struct AugmentedNDELayer{DE<:Union{NeuralDELayer,NeuralSDELayer}} <: AugmentedNDEType nde::DE adim::Int end diff --git a/test/Project.toml b/test/Project.toml index 2ece6eac59..76d81da842 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd" diff --git a/test/augmented_nde.jl b/test/augmented_nde.jl index f0b0db973c..269d10e213 100644 --- a/test/augmented_nde.jl +++ b/test/augmented_nde.jl @@ -1,4 +1,4 @@ -using DiffEqFlux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test +using DiffEqFlux, Lux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random x = Float32[2.; 0.] xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) @@ -14,7 +14,7 @@ anode = AugmentedNDELayer( ) anode(x) -grads = Zygote.gradient(() -> sum(anode(x)), Flux.params(x, anode)) +grads = Zygote.gradient(() -> sum(anode(x)), Flux.params(x, anode.nde)) @test ! iszero(grads[x]) @test ! iszero(grads[anode.p]) @@ -24,7 +24,7 @@ andsde = AugmentedNDELayer( ) andsde(x) -grads = Zygote.gradient(() -> sum(andsde(x)), Flux.params(x, andsde)) +grads = Zygote.gradient(() -> sum(andsde(x)), Flux.params(x, andsde.nde)) @test ! iszero(grads[x]) @test ! 
iszero(grads[andsde.p]) @@ -34,7 +34,7 @@ asode = AugmentedNDELayer( ) asode(x) -grads = Zygote.gradient(() -> sum(asode(x)), Flux.params(x, asode)) +grads = Zygote.gradient(() -> sum(asode(x)), Flux.params(x, asode.nde)) @test ! iszero(grads[x]) @test ! iszero(grads[asode.p]) @@ -45,6 +45,50 @@ adode = AugmentedNDELayer( ) adode(x) -grads = Zygote.gradient(() -> sum(adode(x)), Flux.params(x, adode)) +grads = Zygote.gradient(() -> sum(adode(x)), Flux.params(x, adode.nde)) @test ! iszero(grads[x]) @test ! iszero(grads[adode.p]) + +## AugmentedNDELayer with Lux + +rng = Random.default_rng() + +dudt = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4)) +dudt2 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4)) +dudt22 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 16), (x) -> reshape(x, 4, 4)) + +# Augmented Neural ODE +anode = AugmentedNDELayer( + NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false), 2 +) +pd, st = Lux.setup(rng, anode) +pd = Lux.ComponentArray(pd) +anode(x,pd,st) + +grads = Zygote.gradient((x,p,st) -> sum(anode(x,p,st)[1]), x, pd, st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +# Augmented Neural DSDE +andsde = AugmentedNDELayer( + NeuralDSDE(dudt, dudt2, (0.0f0, 0.1f0), EulerHeun(), saveat=0.0:0.01:0.1, dt=0.01), 2 +) +pd, st = Lux.setup(rng, andsde) +pd = Lux.ComponentArray(pd) +andsde(x,pd,st) + +grads = Zygote.gradient((x,p,st) -> sum(andsde(x,p,st)[1]), x, pd, st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +# Augmented Neural SDE +asode = AugmentedNDELayer( + NeuralSDE(dudt, dudt22,(0.0f0, 0.1f0), 4, EulerHeun(), saveat=0.0:0.01:0.1, dt=0.01), 2 +) +pd, st = Lux.setup(rng, asode) +pd = Lux.ComponentArray(pd) +asode(x,pd,st) + +grads = Zygote.gradient((x,p,st) -> sum(asode(x,p,st)[1]), x, pd, st) +@test ! iszero(grads[1]) +@test ! 
iszero(grads[2]) diff --git a/test/neural_de_lux.jl b/test/neural_de_lux.jl new file mode 100644 index 0000000000..001fb4e69c --- /dev/null +++ b/test/neural_de_lux.jl @@ -0,0 +1,283 @@ +using DiffEqFlux, Lux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random + +rng = Random.default_rng() + +mp = Float32[0.1,0.1] +x = Float32[2.; 0.] +xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.])) +tspan = (0.0f0,1.0f0) +dudt = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2)) +luxdudt = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2)) + +NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(x) +NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(x) +NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(x) + +NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false)(xs) +NeuralODE(dudt,tspan,Tsit5(),saveat=0.1)(xs) +NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())(xs) + +@info "Test some gradients" + +node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) +grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) +@test ! iszero(grads[x]) +@test ! iszero(grads[node.p]) + +grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) +@test ! iszero(grads[xs]) +@test ! iszero(grads[node.p]) + +node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) +grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) +@test ! iszero(grads[x]) +@test ! iszero(grads[node.p]) + +grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) +@test ! iszero(grads[xs]) +@test ! iszero(grads[node.p]) + +node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint()) +grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) +@test ! iszero(grads[x]) +@test ! iszero(grads[node.p]) + +grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) +@test ! iszero(grads[xs]) +@test ! 
iszero(grads[node.p]) + +node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint(autojacvec=ZygoteVJP())) +grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) +@test ! iszero(grads[x]) +@test ! iszero(grads[node.p]) + +grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) +@test ! iszero(grads[xs]) +@test ! iszero(grads[node.p]) + +## Lux + +@info "Test some Lux layers" + +node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false) +pd, st = Lux.setup(rng, node) +pd = Lux.ComponentArray(pd) +grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +#test with low tolerance ode solver +node = NeuralODE(luxdudt, tspan, Tsit5(), abstol=1e-12, reltol=1e-12, save_everystep=false, save_start=false) +grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +# node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint()) +# @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) +# @test ! iszero(grads[1]) +# @test ! iszero(grads[2]) + +# @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) +# @test ! iszero(grads[1]) +# @test ! iszero(grads[2]) + +node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=BacksolveAdjoint()) +grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) +@test ! iszero(grads[1]) +@test ! iszero(grads[2]) + +grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) +@test ! iszero(grads[1]) +@test ! 
iszero(grads[2]) + +@info "Test some adjoints" + +# Adjoint +@testset "adjoint mode" begin + node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false) + grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) + @test ! iszero(grads[x]) + @test ! iszero(grads[node.p]) + + grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) + @test ! iszero(grads[xs]) + @test ! iszero(grads[node.p]) + + node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:1.0) + grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) + @test ! iszero(grads[x]) + @test ! iszero(grads[node.p]) + + grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) + @test ! iszero(grads[xs]) + @test ! iszero(grads[node.p]) + + node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.1) + grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node)) + @test ! iszero(grads[x]) + @test ! iszero(grads[node.p]) + + grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node)) + @test ! iszero(grads[xs]) + @test ! iszero(grads[node.p]) + + node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false) + grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) + @test ! iszero(grads[1]) + @test ! iszero(grads[2]) + + grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0) + grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) + @test ! iszero(grads[1]) + @test ! iszero(grads[2]) + + grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st) + @test !iszero(grads[1]) + @test !iszero(grads[2]) + + node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.1) + grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st) + @test ! iszero(grads[1]) + @test ! 
iszero(grads[2])
+
+    grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st)
+    @test !iszero(grads[1])
+    @test !iszero(grads[2])
+end
+
+@info "Test Tracker"
+
+# RD
+@testset "Tracker mode" begin
+    node = NeuralODE(dudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
+    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
+    @test ! iszero(grads[x])
+    @test ! iszero(grads[node.p])
+
+    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
+    @test ! iszero(grads[xs])
+    @test ! iszero(grads[node.p])
+
+    node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint())
+    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
+    @test ! iszero(grads[x])
+    @test ! iszero(grads[node.p])
+
+    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
+    @test ! iszero(grads[xs])
+    @test ! iszero(grads[node.p])
+
+    node = NeuralODE(dudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())
+    grads = Zygote.gradient(()->sum(node(x)),Flux.params(x,node))
+    @test ! iszero(grads[x])
+    @test ! iszero(grads[node.p])
+
+    grads = Zygote.gradient(()->sum(node(xs)),Flux.params(xs,node))
+    @test ! iszero(grads[xs])
+    @test ! iszero(grads[node.p])
+
+    # node = NeuralODE(luxdudt,tspan,Tsit5(),save_everystep=false,save_start=false,sensealg=TrackerAdjoint())
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st)
+    # @test ! iszero(grads[1])
+    # @test ! iszero(grads[2])
+
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st)
+    # @test_broken ! iszero(grads[1])
+    # @test_broken ! iszero(grads[2])
+
+    # node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.0:0.1:1.0,sensealg=TrackerAdjoint())
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st)
+    # @test ! iszero(grads[1])
+    # @test ! iszero(grads[2])
+
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st)
+    # @test_broken ! iszero(grads[1])
+    # @test_broken ! iszero(grads[2])
+
+    # node = NeuralODE(luxdudt,tspan,Tsit5(),saveat=0.1,sensealg=TrackerAdjoint())
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),x,pd,st)
+    # @test ! iszero(grads[1])
+    # @test ! iszero(grads[2])
+
+    # @test_broken grads = Zygote.gradient((x,p,st)->sum(node(x,p,st)[1]),xs,pd,st)
+    # @test_broken ! iszero(grads[1])
+    # @test_broken ! iszero(grads[2])
+end
+
+@info "Test non-ODEs"
+
+dudt2 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,2))
+luxdudt2 = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,2))
+NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.1)(x)
+sode = NeuralDSDE(dudt,dudt2,(0.0f0,.1f0),SOSRI(),saveat=0.0:0.01:0.1)
+
+grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
+@test ! iszero(grads[x])
+@test ! iszero(grads[sode.p])
+@test ! iszero(grads[sode.p][end])
+
+grads = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
+@test ! iszero(grads[xs])
+@test ! iszero(grads[sode.p])
+@test ! iszero(grads[sode.p][end])
+
+sode = NeuralDSDE(luxdudt,luxdudt2,(0.0f0,.1f0),EulerHeun(),saveat=0.0:0.01:0.1,dt=0.1)
+pd, st = Lux.setup(rng, sode)
+pd = Lux.ComponentArray(pd)
+
+grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),x,pd,st)
+@test ! iszero(grads[1])
+@test ! iszero(grads[2])
+@test ! iszero(grads[2][end])
+
+grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),xs,pd,st)
+@test ! iszero(grads[1])
+@test ! iszero(grads[2])
+@test ! iszero(grads[2][end])
+
+dudt22 = Flux.Chain(Flux.Dense(2,50,tanh),Flux.Dense(50,4),x->reshape(x,2,2))
+luxdudt22 = Lux.Chain(Lux.Dense(2,50,tanh),Lux.Dense(50,4),x->reshape(x,2,2))
+NeuralSDE(dudt,dudt22,(0.0f0,.1f0),2,LambaEM(),saveat=0.01)(x)
+
+sode = NeuralSDE(dudt,dudt22,(0.0f0,0.1f0),2,LambaEM(),saveat=0.0:0.01:0.1)
+
+grads = Zygote.gradient(()->sum(sode(x)),Flux.params(x,sode))
+@test ! iszero(grads[x])
+@test ! iszero(grads[sode.p])
+@test ! iszero(grads[sode.p][end])
+
+# Batched case is marked broken: compute and check in one expression so no stale `grads` is reused.
+@test_broken let gs = Zygote.gradient(()->sum(sode(xs)),Flux.params(xs,sode))
+    !iszero(gs[xs]) && !iszero(gs[sode.p]) && !iszero(gs[sode.p][end])
+end
+
+sode = NeuralSDE(luxdudt,luxdudt22,(0.0f0,0.1f0),2,EulerHeun(),saveat=0.0:0.01:0.1,dt=0.01)
+pd,st = Lux.setup(rng, sode)
+pd = Lux.ComponentArray(pd)
+
+grads = Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),x,pd,st)
+@test ! iszero(grads[1])
+@test ! iszero(grads[2])
+@test ! iszero(grads[2][end])
+
+@test_broken Zygote.gradient((x,p,st)->sum(sode(x,p,st)[1]),xs,pd,st) isa Tuple # batched input; sum the solution ([1]), not the (solution, state) pair
+
+ddudt = Flux.Chain(Flux.Dense(6,50,tanh),Flux.Dense(50,2))
+NeuralCDDE(ddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.1)(x)
+dode = NeuralCDDE(ddudt,(0.0f0,2.0f0),(p,t)->zero(x),(1f-1,2f-1),MethodOfSteps(Tsit5()),saveat=0.0:0.1:2.0)
+
+grads = Zygote.gradient(()->sum(dode(x)),Flux.params(x,dode))
+@test ! iszero(grads[x])
+@test ! iszero(grads[dode.p])
+# Batched case is marked broken: one Boolean check on a local gradient (no `isa Tuple`, no stale `grads`).
+@test_broken let gs = Zygote.gradient(()->sum(dode(xs)),Flux.params(xs,dode))
+    !iszero(gs[xs]) && !iszero(gs[dode.p])
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d8a9e51abe..659f5307a2 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,6 +11,7 @@ if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "Layers"
 end
 
 if GROUP == "All" || GROUP == "DiffEqFlux" || GROUP == "BasicNeuralDE"
+    @safetestset "Neural DE Tests with Lux" begin include("neural_de_lux.jl") end
     @safetestset "Neural DE Tests" begin include("neural_de.jl") end
     @safetestset "Augmented Neural DE Tests" begin include("augmented_nde.jl") end
     #@safetestset "Neural Graph DE" begin include("neural_gde.jl") end