neural_sde example in Flux #738

Merged 10 commits on Jun 23, 2022
32 changes: 13 additions & 19 deletions docs/src/examples/neural_sde.md
@@ -33,9 +33,8 @@ First let's build training data from the same example as the neural ODE:

 ```@example nsde
 using Plots, Statistics
-using Lux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis, Random
+using Flux, Optimization, OptimizationFlux, DiffEqFlux, StochasticDiffEq, SciMLBase.EnsembleAnalysis

-rng = Random.default_rng()
 u0 = Float32[2.; 0.]
 datasize = 30
 tspan = (0.0f0, 1.0f0)
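
Side note on the pattern this hunk switches to: `Flux.destructure` flattens a model's trainable parameters into a single vector and returns a closure that rebuilds the model from any vector of the same length. A minimal sketch of that contract (standalone, not part of the diff):

```julia
using Flux

# `destructure` returns the flat parameter vector `p` and a reconstructor `re`;
# `re(p)` rebuilds a model with the same architecture and the given parameters.
model = Flux.Chain(Flux.Dense(2, 50, tanh), Flux.Dense(50, 2))
p, re = Flux.destructure(model)

x = rand(Float32, 2)
@assert re(p)(x) ≈ model(x)   # reconstruction reproduces the forward pass
```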
@@ -73,18 +72,13 @@ Now we build a neural SDE. For simplicity we will use the `NeuralDSDE`
 neural SDE with diagonal noise layer function:

 ```@example nsde
-drift_dudt = Lux.Chain(ActivationFunction(x -> x.^3),
-                       Lux.Dense(2, 50, tanh),
-                       Lux.Dense(50, 2))
-p1, st1 = Lux.setup(rng, drift_dudt)
+drift_dudt = Flux.Chain(x -> x.^3,
+                        Flux.Dense(2, 50, tanh),
+                        Flux.Dense(50, 2))
+p1, re1 = Flux.destructure(drift_dudt)

-diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
-p2, st2 = Lux.setup(rng, diffusion_dudt)
-
-p1 = Lux.ComponentArray(p1)
-p2 = Lux.ComponentArray(p2)
-#Component Arrays doesn't provide a name to the first ComponentVector, only subsequent ones get a name for dereferencing
-p = [p1, p2]
+diffusion_dudt = Flux.Chain(Flux.Dense(2, 2))
+p2, re2 = Flux.destructure(diffusion_dudt)

 neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                        saveat = tsteps, reltol = 1e-1, abstol = 1e-1)
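
Note: with diagonal noise, the drift and diffusion networks both map the 2-dimensional state to 2 outputs, and each diffusion output scales an independent Brownian increment. A quick shape check, assuming the `drift_dudt`/`diffusion_dudt` definitions from this hunk:

```julia
# Diagonal-noise SDE: du = f(u) dt + g(u) .* dW, with one Brownian
# component per state variable, so both networks map R^2 -> R^2.
f_out = drift_dudt(Float32[2.0, 0.0])      # drift term, length 2
g_out = diffusion_dudt(Float32[2.0, 0.0])  # diagonal diffusion, length 2
@assert length(f_out) == length(g_out) == 2
```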
@@ -94,12 +88,12 @@ Let's see what that looks like:

 ```@example nsde
 # Get the prediction using the correct initial condition
-prediction0, st1, st2 = neuralsde(u0,p,st1,st2)
+prediction0 = neuralsde(u0)

-drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
-diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]
+drift_(u, p, t) = re1(p[1:neuralsde.len])(u)
+diffusion_(u, p, t) = re2(p[neuralsde.len+1:end])(u)

-prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), p)
+prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), neuralsde.p)

 ensemble_nprob = EnsembleProblem(prob_neuralsde)
 ensemble_nsol = solve(ensemble_nprob, SOSRI(), trajectories = 100,
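
Note: `NeuralDSDE` concatenates both networks' flat parameters into `neuralsde.p`, and `neuralsde.len` marks where the drift parameters end, which is what the `drift_`/`diffusion_` closures above index into. A sketch of that invariant, assuming the definitions from the previous hunks:

```julia
# neuralsde.p is the concatenation of p1 (drift) and p2 (diffusion),
# so the drift owns p[1:len] and the diffusion owns p[len+1:end].
@assert neuralsde.len == length(p1)
@assert length(neuralsde.p) == length(p1) + length(p2)
```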
@@ -119,7 +113,7 @@ the data values:

 ```@example nsde
 function predict_neuralsde(p, u = u0)
-    return Array(neuralsde(u, p, st1, st2)[1])
+    return Array(neuralsde(u, p))
 end

 function loss_neuralsde(p; n = 100)
@@ -172,7 +166,7 @@ opt = ADAM(0.025)
 # First round of training with n = 10
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x,p) -> loss_neuralsde(x, n=10), adtype)
-optprob = Optimization.OptimizationProblem(optf, p)
+optprob = Optimization.OptimizationProblem(optf, neuralsde.p)
 result1 = Optimization.solve(optprob, opt,
                              callback = callback, maxiters = 100)
 ```
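
Note on the training flow this hunk plugs into: the example first optimizes against a cheap n = 10 Monte Carlo loss estimate and then, further down in the file, refines from the first-round optimum with a larger n. A hedged sketch of that second stage (the n = 100 value is illustrative; `loss_neuralsde`, `opt`, and `callback` come from the example):

```julia
# Second round: restart from result1.u with a less noisy loss estimate.
optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x, n = 100),
                                          Optimization.AutoZygote())
optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
result2 = Optimization.solve(optprob2, opt,
                             callback = callback, maxiters = 100)
```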
71 changes: 63 additions & 8 deletions src/neural_de.jl
@@ -1,4 +1,4 @@
-abstract type NeuralDELayer <: Function end
+abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
 basic_tgrad(u,p,t) = zero(u)
 Flux.trainable(m::NeuralDELayer) = (m.p,)

@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
     end
 end

+Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model)
+Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model)
+
 function (n::NeuralODE)(x,p=n.p)
     dudt_(u,p,t) = n.re(p)(u)
     ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
 end

 function (n::NeuralODE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
-    function dudt(u,p,t)
+    function dudt(u,p,t;st=st)
         u_, st = n.model(u,p,st)
         return u_
     end
+
     ff = ODEFunction{false}(dudt,tgrad=basic_tgrad)
     prob = ODEProblem{false}(ff,x,n.tspan,p)
     sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
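
Note: since `NeuralDELayer` now subtypes `Lux.AbstractExplicitLayer` and `NeuralODE` forwards `initialparameters`/`initialstates` to the wrapped model, a `NeuralODE` over a Lux model can be initialized through the standard `Lux.setup` path. A minimal sketch under those assumptions (network and solver choices are illustrative):

```julia
using Lux, DiffEqFlux, OrdinaryDiffEq, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
node = NeuralODE(model, (0.0f0, 1.0f0), Tsit5(), saveat = 0.1f0)

# Lux.setup hits the initialparameters/initialstates hooks added above.
p, st = Lux.setup(rng, node)
out = node(Float32[2.0, 0.0], p, st)   # dispatches to the Lux method
```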
@@ -179,19 +183,33 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
     solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
 end

-function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
-    function dudt_(u,p,t)
-        u_, st1 = n.model1(u,p[1],st1)
+function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE)
+    p1 = Lux.initialparameters(rng, n.model1)
+    p2 = Lux.initialparameters(rng, n.model2)
+    return Lux.ComponentArray((p1 = p1, p2 = p2))
+end
+
+function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
+    st1 = Lux.initialstates(rng, n.model1)
+    st2 = Lux.initialstates(rng, n.model2)
+    return (state1 = st1, state2 = st2)
+end
+
+function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
+    st1 = st.state1
+    st2 = st.state2
+    function dudt_(u,p,t;st=st1)
+        u_, st = n.model1(u,p.p1,st)
         return u_
     end
-    function g(u,p,t)
-        u_, st2 = n.model2(u,p[2],st2)
+    function g(u,p,t;st=st2)
+        u_, st = n.model2(u,p.p2,st)
         return u_
     end

     ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
     prob = SDEProblem{false}(ff,g,x,n.tspan,p)
-    return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
+    return solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
 end

 """
@@ -251,6 +269,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
             typeof(tspan),typeof(args),typeof(kwargs)}(
             p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
     end
+
+    function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...;
+                       p1 = nothing, p = nothing, kwargs...)
+        re1 = nothing
+        re2 = nothing
+        new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
+            typeof(tspan),typeof(args),typeof(kwargs)}(
+            p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
+    end
 end

 function (n::NeuralSDE)(x,p=n.p)
@@ -269,6 +296,34 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
     solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
 end

+function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE)
+    p1 = Lux.initialparameters(rng, n.model1)
+    p2 = Lux.initialparameters(rng, n.model2)
+    return Lux.ComponentArray((p1 = p1, p2 = p2))
+end
+function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
+    st1 = Lux.initialstates(rng, n.model1)
+    st2 = Lux.initialstates(rng, n.model2)
+    return (state1 = st1, state2 = st2)
+end
+
+function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
+    st1 = st.state1
+    st2 = st.state2
+    function dudt_(u,p,t;st=st1)
+        u_, st = n.model1(u,p.p1,st)
+        return u_
+    end
+    function g(u,p,t;st=st2)
+        u_, st = n.model2(u,p.p2,st)
+        return u_
+    end
+
+    ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
+    prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
+    solve(prob,n.args...;sensealg=InterpolatingAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
+end
+
 """
 Constructs a neural delay differential equation (neural DDE) with constant
 delays.
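
Note: unlike the diagonal-noise `NeuralDSDE`, this `NeuralSDE` method passes `noise_rate_prototype = zeros(Float32, length(x), n.nbrown)`, i.e. non-diagonal noise where the diffusion output is a `length(x) × nbrown` matrix mixing `nbrown` Brownian motions into the state. A hedged sketch of consistent shapes (the sizes, the reshape layer, and the solver choice are illustrative; the `SOSRI()` used in the docs example only handles diagonal noise):

```julia
using Lux, DiffEqFlux, StochasticDiffEq, Random

rng = Random.default_rng()
nbrown = 3   # number of Brownian motions, illustrative

drift = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
# The diffusion output must match the 2 × nbrown noise-rate prototype:
diffusion = Lux.Chain(Lux.Dense(2, 2 * nbrown), x -> reshape(x, 2, nbrown))

sde = NeuralSDE(drift, diffusion, (0.0f0, 1.0f0), nbrown, LambaEM(),
                saveat = 0.1f0)

p, st = Lux.setup(rng, sde)
sol, st = sde(Float32[2.0, 0.0], p, st)
```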