Initial changes for NeuralDSDE and NeuralSDE #736

Closed · wants to merge 9 commits
Changes from all commits
16 changes: 6 additions & 10 deletions docs/src/examples/neural_sde.md
@@ -76,28 +76,24 @@ neural SDE with diagonal noise layer function:
drift_dudt = Lux.Chain(ActivationFunction(x -> x.^3),
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2))
p1, st1 = Lux.setup(rng, drift_dudt)

diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))
p2, st2 = Lux.setup(rng, diffusion_dudt)

p1 = Lux.ComponentArray(p1)
p2 = Lux.ComponentArray(p2)
# ComponentArrays doesn't provide a name for the first ComponentVector; only subsequent ones get a name for dereferencing
p = [p1, p2]

neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
Review comment from the PR author:
The choice of solver here is leading to https://github.com/SciML/DiffEqFlux.jl/runs/7005695604?check_suite_focus=true#step:5:22
We'll either need a different sensealg for NeuralDSDE (currently InterpolatingAdjoint) that is compatible with Lux, or change the solver.
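
For example, the solver-change option could look like the following (untested suggestion; `EM()` and the `dt` value are illustrative, not part of this PR):

```julia
# hypothetical alternative: fixed-step Euler-Maruyama instead of adaptive SOSRI()
neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, EM(),
                       dt = 0.01f0, saveat = tsteps)
```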

saveat = tsteps, reltol = 1e-1, abstol = 1e-1)

p, st = Lux.setup(rng, neuralsde)
```
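
For orientation, here is a minimal end-to-end sketch of the Lux-style workflow this change introduces. The drift is simplified to two `Dense` layers, the rng is illustrative, and `u0`, `tspan`, and `tsteps` are assumed to be defined earlier in the tutorial:

```julia
using Lux, DiffEqFlux, DifferentialEquations, Random

rng = Random.default_rng()

drift_dudt = Lux.Chain(Lux.Dense(2, 50, tanh), Lux.Dense(50, 2))
diffusion_dudt = Lux.Chain(Lux.Dense(2, 2))

neuralsde = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                       saveat = tsteps, reltol = 1e-1, abstol = 1e-1)

# Parameters come back as a ComponentArray with fields p1/p2 and states as a
# NamedTuple with fields state1/state2 (see the Lux.initialparameters /
# Lux.initialstates hooks in src/neural_de.jl below).
p, st = Lux.setup(rng, neuralsde)

# The forward pass returns the SDE solution and the (unchanged) state tuple.
sol, st = neuralsde(u0, p, st)
```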

Let's see what that looks like:

```@example nsde
# Get the prediction using the correct initial condition
prediction0, st1, st2 = neuralsde(u0,p,st1,st2)
prediction0, st = neuralsde(u0,p,st)

drift_(u, p, t) = drift_dudt(u, p[1], st1)[1]
diffusion_(u, p, t) = diffusion_dudt(u, p[2], st2)[1]
drift_(u, p, t) = drift_dudt(u, p.p1, st.state1)[1]
diffusion_(u, p, t) = diffusion_dudt(u, p.p2, st.state2)[1]

prob_neuralsde = SDEProblem(drift_, diffusion_, u0,(0.0f0, 1.2f0), p)

@@ -119,7 +115,7 @@ the data values:

```@example nsde
function predict_neuralsde(p, u = u0)
return Array(neuralsde(u, p, st1, st2)[1])
return Array(neuralsde(u, p, st)[1])
end

function loss_neuralsde(p; n = 100)
71 changes: 63 additions & 8 deletions src/neural_de.jl
@@ -1,4 +1,4 @@
abstract type NeuralDELayer <: Function end
abstract type NeuralDELayer <: Lux.AbstractExplicitLayer end
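# Subtyping Lux.AbstractExplicitLayer is what lets Lux.setup(rng, layer) dispatch to the
# initialparameters/initialstates hooks defined for the neural DE layers below.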
basic_tgrad(u,p,t) = zero(u)
Flux.trainable(m::NeuralDELayer) = (m.p,)

@@ -69,6 +69,9 @@ struct NeuralODE{M,P,RE,T,A,K} <: NeuralDELayer
end
end

Lux.initialparameters(rng::AbstractRNG, n::NeuralODE) = Lux.initialparameters(rng, n.model)
Lux.initialstates(rng::AbstractRNG, n::NeuralODE) = Lux.initialstates(rng, n.model)
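
For reference, a minimal sketch of how these delegating hooks are exercised; the chain, solver, tspan, and input are illustrative, not from this PR:

```julia
using Lux, DiffEqFlux, OrdinaryDiffEq, Random

rng = Random.default_rng()
chain = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
node = NeuralODE(chain, (0.0f0, 1.0f0), Tsit5(), saveat = 0.1f0)

# setup simply forwards to the wrapped model, so p/st have the chain's layout
p, st = Lux.setup(rng, node)
node(Float32[1.0, 0.0], p, st)  # forward pass via the (x, p, st) method below
```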

function (n::NeuralODE)(x,p=n.p)
dudt_(u,p,t) = n.re(p)(u)
ff = ODEFunction{false}(dudt_,tgrad=basic_tgrad)
@@ -86,10 +89,11 @@ function (n::NeuralODE{M})(x,p=n.p) where {M<:FastChain}
end

function (n::NeuralODE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
function dudt(u,p,t)
function dudt(u,p,t;st=st)
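# Taking `st` as a keyword default makes the state a local of this closure, so the
# reassignment below rebinds the local instead of boxing the captured `st`.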
u_, st = n.model(u,p,st)
return u_
end

ff = ODEFunction{false}(dudt,tgrad=basic_tgrad)
prob = ODEProblem{false}(ff,x,n.tspan,p)
sense = InterpolatingAdjoint(autojacvec=ZygoteVJP())
@@ -179,19 +183,33 @@ function (n::NeuralDSDE{M})(x,p=n.p) where {M<:FastChain}
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
function dudt_(u,p,t)
u_, st1 = n.model1(u,p[1],st1)
function Lux.initialparameters(rng::AbstractRNG, n::NeuralDSDE)
p1 = Lux.initialparameters(rng, n.model1)
p2 = Lux.initialparameters(rng, n.model2)
return Lux.ComponentArray((p1 = p1, p2 = p2))
end

function Lux.initialstates(rng::AbstractRNG, n::NeuralDSDE)
st1 = Lux.initialstates(rng, n.model1)
st2 = Lux.initialstates(rng, n.model2)
return (state1 = st1, state2 = st2)
end

function (n::NeuralDSDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
st1 = st.state1
st2 = st.state2
function dudt_(u,p,t;st=st1)
u_, st = n.model1(u,p.p1,st)
return u_
end
function g(u,p,t)
u_, st2 = n.model2(u,p[2],st2)
function g(u,p,t;st=st2)
u_, st = n.model2(u,p.p2,st)
return u_
end

ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
prob = SDEProblem{false}(ff,g,x,n.tspan,p)
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
return solve(prob,n.args...;sensealg=ReverseDiffAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
end
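
As a sanity check of this dispatch, taking a gradient through the new interface might look like the sketch below. It assumes the `neuralsde`, `u0`, `p`, and `st` from the docs example; whether ReverseDiffAdjoint is the right adjoint here is still open per the review comment above:

```julia
using Zygote

loss(p) = sum(abs2, Array(neuralsde(u0, p, st)[1]))
grads = Zygote.gradient(loss, p)[1]   # differentiates through the solve above
```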

"""
@@ -251,6 +269,15 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
typeof(tspan),typeof(args),typeof(kwargs)}(
p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
end

function NeuralSDE(model1::Lux.AbstractExplicitLayer, model2::Lux.AbstractExplicitLayer,tspan,nbrown,args...;
p1 = nothing, p = nothing, kwargs...)
re1 = nothing
re2 = nothing
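# Lux layers carry no Flux-style `re` reconstruction and their parameters come from
# Lux.setup, so `re1`/`re2` stay `nothing` and the stored length field is only a placeholder.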
new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
typeof(tspan),typeof(args),typeof(kwargs)}(
p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
end
end

function (n::NeuralSDE)(x,p=n.p)
@@ -269,6 +296,34 @@ function (n::NeuralSDE{P,M})(x,p=n.p) where {P,M<:FastChain}
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function Lux.initialparameters(rng::AbstractRNG, n::NeuralSDE)
p1 = Lux.initialparameters(rng, n.model1)
p2 = Lux.initialparameters(rng, n.model2)
return Lux.ComponentArray((p1 = p1, p2 = p2))
end
function Lux.initialstates(rng::AbstractRNG, n::NeuralSDE)
st1 = Lux.initialstates(rng, n.model1)
st2 = Lux.initialstates(rng, n.model2)
return (state1 = st1, state2 = st2)
end

function (n::NeuralSDE{P,M})(x,p,st) where {P,M<:Lux.AbstractExplicitLayer}
st1 = st.state1
st2 = st.state2
function dudt_(u,p,t;st=st1)
u_, st = n.model1(u,p.p1,st)
return u_
end
function g(u,p,t;st=st2)
u_, st = n.model2(u,p.p2,st)
return u_
end

ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
solve(prob,n.args...;sensealg=ReverseDiffAdjoint(),n.kwargs...), (state1 = st1, state2 = st2)
end
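
And a corresponding sketch for the non-diagonal case handled by this dispatch. The 2-state / 2-Brownian sizes, the reshape trick for producing the noise-rate matrix, and the fixed-step EM solver are all illustrative assumptions, not from this PR:

```julia
using Lux, DiffEqFlux, DifferentialEquations, Random

rng = Random.default_rng()
nbrown = 2

drift = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
# The diffusion network must return a length(x) x nbrown noise-rate matrix,
# matching the noise_rate_prototype used above.
diffusion = Lux.Chain(Lux.Dense(2, 2 * nbrown),
                      Lux.WrappedFunction(x -> reshape(x, 2, nbrown)))

sde = NeuralSDE(drift, diffusion, (0.0f0, 1.0f0), nbrown, EM(),
                dt = 0.01f0, saveat = 0.1f0)

p, st = Lux.setup(rng, sde)    # ComponentArray p.p1/p.p2, states st.state1/st.state2
sol, st = sde(Float32[1.0, 0.0], p, st)
```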

"""
Constructs a neural delay differential equation (neural DDE) with constant
delays.