Merge pull request #722 from Abhishek-1Bhatt/remove_deprecations
Removal of deprecations
ChrisRackauckas authored Jun 14, 2022
2 parents 30e09f3 + ad35f1e commit fb75ffa
Showing 15 changed files with 488 additions and 461 deletions.
18 changes: 1 addition & 17 deletions src/ffjord.jl
@@ -18,7 +18,7 @@ After these steps one may use the NN model and the learned θ to predict the den
FFJORD(model, basedist=nothing, monte_carlo=false, tspan, args...; kwargs...)
```
Arguments:
- `model`: A Chain neural network that defines the dynamics of the model.
- `model`: A Flux.Chain or Lux.AbstractExplicitLayer that defines the dynamics of the model.
- `basedist`: Distribution of the base variable. Set to the unit normal by default.
- `tspan`: The timespan to be solved on.
- `kwargs`: Additional arguments splatted to the ODE solver. See the
@@ -61,7 +61,6 @@ struct FFJORD{M, P, ST, RE, D, T, A, K} <: CNFLayer where {M, P <: AbstractVecto
end

function FFJORD(model, tspan, args...;p=nothing, st=nothing, basedist=nothing, kwargs...)
#
_p, re = Flux.destructure(model)
if isnothing(p)
p = _p
@@ -74,24 +73,9 @@ struct FFJORD{M, P, ST, RE, D, T, A, K} <: CNFLayer where {M, P <: AbstractVecto
new{typeof(model),typeof(p),typeof(st),typeof(re),
typeof(basedist),typeof(tspan),typeof(args),typeof(kwargs)}(
model,p,st,re,basedist,tspan,args,kwargs)
# FFJORD(model, p, st, re, basedist, tspan, args, kwargs)
end
end

# function FFJORD(model, tspan, args...;
# p::P=nothing, st=nothing, basedist::D=nothing, kwargs...) where {P <: Union{AbstractVector{<: AbstractFloat}, Nothing}, RE <: Function, D <: Union{Distribution, Nothing}}
# _p, re = Flux.destructure(model)
# if isnothing(p)
# p = _p
# end
# if isnothing(basedist)
# size_input = size(model[1].weight, 2)
# type_input = eltype(model[1].weight)
# basedist = MvNormal(zeros(type_input, size_input), Diagonal(ones(type_input, size_input)))
# end
# FFJORD(model, p, st, re, basedist, tspan, args, kwargs)
# end

_norm_batched(x::AbstractMatrix) = sqrt.(sum(x.^2, dims=1))

function jacobian_fn(f, x::AbstractVector, args...)
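
For orientation, a minimal usage sketch of the FFJORD constructor retained above. The solver, layer sizes, batch data, and the indexing of the returned tuple are illustrative assumptions, not part of this commit:

```julia
# Hypothetical sketch: construct a Flux-based FFJORD layer and evaluate the
# learned log-density on a small batch. Sizes and solver are placeholders.
using DiffEqFlux, Flux, OrdinaryDiffEq

nn = Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 2))
tspan = (0.0f0, 1.0f0)
ffjord = FFJORD(nn, tspan, Tsit5())   # p, st, basedist keep their defaults

x = rand(Float32, 2, 8)               # batch of 8 two-dimensional samples
logpx = ffjord(x)[1]                  # first tuple element is assumed to be log p(x)
```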
28 changes: 25 additions & 3 deletions src/hnn.jl
@@ -20,11 +20,11 @@ to define a training loop to circumvent this issue.
```julia
HamiltonianNN(model; p = nothing)
HamiltonianNN(model::Lux.AbstractExplicitLayer; p = initial_params(model))
HamiltonianNN(model::Lux.AbstractExplicitLayer; p = nothing)
```
Arguments:
1. `model`: A Chain or Lux.AbstractExplicitLayer neural network that returns the Hamiltonian of the
1. `model`: A Flux.Chain or Lux.AbstractExplicitLayer neural network that returns the Hamiltonian of the
system.
2. `p`: The initial parameters of the neural network.
@@ -45,17 +45,29 @@ struct HamiltonianNN{M, R, P}
end
return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
end

function HamiltonianNN(model::Lux.Chain; p = nothing)
return new{typeof(model), typeof(re), typeof(p)}(model, re, p)
end
end

Flux.trainable(hnn::HamiltonianNN) = (hnn.p,)

function _hamiltonian_forward(re, p, x)
function _hamiltonian_forward(re, p, x, args...)
H = Flux.gradient(x -> sum(re(p)(x)), x)[1]
n = size(x, 1) ÷ 2
return cat(H[(n + 1):2n, :], -H[1:n, :], dims=1)
end

function _hamiltonian_forward(re::Lux.Chain, p, x, args...)
st = args[1]
H = Lux.gradient(x -> sum(Lux.apply(re,x,p,st)[1]), x)[1]
n = size(x, 1) ÷ 2
return cat(H[(n + 1):2n, :], -H[1:n, :], dims=1), st
end

(hnn::HamiltonianNN)(x, p = hnn.p) = _hamiltonian_forward(hnn.re, p, x)
(hnn::HamiltonianNN{M})(x, p, st) where {M<:Lux.AbstractExplicitLayer} = _hamiltonian_forward(hnn.model, p, x, st)


"""
@@ -105,3 +117,13 @@ function (nhde::NeuralHamiltonianDE)(x, p = nhde.p)
sense = InterpolatingAdjoint(autojacvec = false)
solve(prob, nhde.args...; sensealg = sense, nhde.kwargs...)
end

function (nhde::NeuralHamiltonianDE{M})(x, p, st) where {M<:Lux.AbstractExplicitLayer}
function neural_hamiltonian!(du, u, p, t)
du .= reshape(nhde.hnn(u, p, st)[1], size(du))
end
prob = ODEProblem(neural_hamiltonian!, x, nhde.tspan, p)
# NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP
sense = InterpolatingAdjoint(autojacvec = false)
solve(prob, nhde.args...; sensealg = sense, nhde.kwargs...)
end
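
For context, a minimal sketch of the Flux path of `HamiltonianNN`, whose forward pass returns the symplectic gradient; the Lux dispatches added above follow the `(x, p, st)` convention instead. Layer sizes and data below are illustrative assumptions:

```julia
# Hypothetical sketch: a Flux-based HamiltonianNN returning (∂H/∂p, -∂H/∂q).
using DiffEqFlux, Flux

model = Flux.Chain(Flux.Dense(2, 64, tanh), Flux.Dense(64, 1))  # H(q, p) -> scalar
hnn = HamiltonianNN(model)

x = rand(Float32, 2, 16)   # columns are (q, p) phase-space points
dx = hnn(x)                # size(dx) == (2, 16): time derivatives (q̇, ṗ)
```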
95 changes: 82 additions & 13 deletions src/neural_de.jl
@@ -125,13 +125,12 @@ struct NeuralDSDE{M,P,RE,M2,RE2,T,A,K} <: NeuralDELayer
end

function NeuralDSDE(model1::Lux.Chain,model2::Lux.Chain,tspan,args...;
p1 =nothing,
p = nothing, kwargs...)
re1 = nothing
re2 = nothing
new{typeof(model1),typeof(p),typeof(re1),typeof(model2),typeof(re2),
typeof(tspan),typeof(args),typeof(kwargs)}(p,
length(p1),model1,re1,model2,re2,tspan,args,kwargs)
Int(1),model1,re1,model2,re2,tspan,args,kwargs)
end
end

@@ -143,13 +142,13 @@ function (n::NeuralDSDE)(x,p=n.p)
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function (n::NeuralDSDE{M})(x,p1,p2,st1,st2) where {M<:Lux.AbstractExplicitLayer}
function (n::NeuralDSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
function dudt_(u,p,t)
u_, st1 = n.model1(u,p1,st1)
u_, st1 = n.model1(u,p[1],st1)
return u_
end
function g(u,p,t)
u_, st2 = n.model2(u,p2,st2)
u_, st2 = n.model2(u,p[2],st2)
return u_
end

Expand Down Expand Up @@ -205,6 +204,14 @@ struct NeuralSDE{P,M,RE,M2,RE2,T,A,K} <: NeuralDELayer
typeof(tspan),typeof(args),typeof(kwargs)}(
p,length(p1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
end

function NeuralSDE(model1::Lux.Chain,model2::Lux.Chain,tspan,nbrown,args...;p=nothing,kwargs...)
re1 = nothing
re2 = nothing
new{typeof(p),typeof(model1),typeof(re1),typeof(model2),typeof(re2),
typeof(tspan),typeof(args),typeof(kwargs)}(
p,Int(1),model1,re1,model2,re2,tspan,nbrown,args,kwargs)
end
end

function (n::NeuralSDE)(x,p=n.p)
Expand All @@ -215,6 +222,20 @@ function (n::NeuralSDE)(x,p=n.p)
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function (n::NeuralSDE{M})(x,p,st1,st2) where {M<:Lux.AbstractExplicitLayer}
function dudt_(u,p,t)
u_, st1 = n.model1(u,p[1],st1)
return u_
end
function g(u,p,t)
u_, st2 = n.model2(u,p[2],st2)
return u_
end
ff = SDEFunction{false}(dudt_,g,tgrad=basic_tgrad)
prob = SDEProblem{false}(ff,g,x,n.tspan,p,noise_rate_prototype=zeros(Float32,length(x),n.nbrown))
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st1, st2
end
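
For orientation, a sketch of the two-network Lux calling convention used by the `NeuralDSDE`/`NeuralSDE` dispatches above (and exercised by the updated tests below): drift and diffusion parameters are passed as `p[1]` and `p[2]`, with the Lux states as separate arguments. Layer sizes, solver, and save points are illustrative assumptions:

```julia
# Hypothetical sketch of the Lux calling convention for NeuralDSDE.
using DiffEqFlux, Lux, StochasticDiffEq, Random

rng = Random.default_rng()
drift     = Lux.Chain(Lux.Dense(2, 32, tanh), Lux.Dense(32, 2))
diffusion = Lux.Chain(Lux.Dense(2, 32, tanh), Lux.Dense(32, 2))
ndsde = NeuralDSDE(drift, diffusion, (0.0f0, 0.1f0), SOSRI(), saveat = 0.0f0:0.01f0:0.1f0)

p1, st1 = Lux.setup(rng, drift)
p2, st2 = Lux.setup(rng, diffusion)
p = [Lux.ComponentArray(p1), Lux.ComponentArray(p2)]  # indexed as p[1], p[2] inside the layer

x = Float32[2.0, 0.0]
out = ndsde(x, p, st1, st2)  # assumed to return the solution followed by the updated states
```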

"""
Constructs a neural delay differential equation (neural DDE) with constant
delays.
@@ -244,9 +265,9 @@ Arguments:
documentation for more details.
"""
struct NeuralCDDE{P,M,RE,H,L,T,A,K} <: NeuralDELayer
p::P
struct NeuralCDDE{M,P,RE,H,L,T,A,K} <: NeuralDELayer
model::M
p::P
re::RE
hist::H
lags::L
Expand All @@ -259,10 +280,17 @@ struct NeuralCDDE{P,M,RE,H,L,T,A,K} <: NeuralDELayer
if p === nothing
p = _p
end
new{typeof(p),typeof(model),typeof(re),typeof(hist),typeof(lags),
typeof(tspan),typeof(args),typeof(kwargs)}(p,model,
new{typeof(model),typeof(p),typeof(re),typeof(hist),typeof(lags),
typeof(tspan),typeof(args),typeof(kwargs)}(model,p,
re,hist,lags,tspan,args,kwargs)
end

function NeuralCDDE(model::Lux.Chain,tspan,hist,lags,args...;p=nothing,kwargs...)
re = nothing
new{typeof(model),typeof(p),typeof(re),typeof(hist),typeof(lags),
typeof(tspan),typeof(args),typeof(kwargs)}(model,p,
re,hist,lags,tspan,args,kwargs)
end
end

function (n::NeuralCDDE)(x,p=n.p)
@@ -275,6 +303,17 @@ function (n::NeuralCDDE)(x,p=n.p)
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function (n::NeuralCDDE{M})(x,p,st) where {M<:Lux.AbstractExplicitLayer}
function dudt_(u,h,p,t)
_u = vcat(u,(h(p,t-lag) for lag in n.lags)...)
u_, st = n.model(_u,p,st)
return u_
end
ff = DDEFunction{false}(dudt_,tgrad=basic_tgrad)
prob = DDEProblem{false}(ff,x,n.hist,n.tspan,p,constant_lags = n.lags)
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st
end
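
For context, a sketch of the Lux-based `NeuralCDDE` constructor and `(x, p, st)` call shown above. With a 2-dimensional state and two constant lags, the network sees `vcat(u, h(p, t - lag1), h(p, t - lag2))`, a 6-vector. The history function, lags, solver, and sizes are illustrative assumptions:

```julia
# Hypothetical sketch of the Lux path for NeuralCDDE.
using DiffEqFlux, Lux, DelayDiffEq, OrdinaryDiffEq, Random

rng = Random.default_rng()
model = Lux.Chain(Lux.Dense(6, 32, tanh), Lux.Dense(32, 2))  # input: state plus two lagged states

hist(p, t) = zeros(Float32, 2)  # history for t <= 0
lags = (0.1f0, 0.2f0)
ncdde = NeuralCDDE(model, (0.0f0, 2.0f0), hist, lags,
                   MethodOfSteps(Tsit5()), saveat = 0.0f0:0.1f0:2.0f0)

p, st = Lux.setup(rng, model)
p = Lux.ComponentArray(p)
x = Float32[2.0, 0.0]
sol, st = ncdde(x, p, st)       # DDE solution and the returned Lux state
```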

"""
Constructs a neural differential-algebraic equation (neural DAE).
@@ -302,10 +341,10 @@ Arguments:
documentation for more details.
"""
struct NeuralDAE{P,M,M2,D,RE,T,DV,A,K} <: NeuralDELayer
struct NeuralDAE{M,P,M2,D,RE,T,DV,A,K} <: NeuralDELayer
model::M
constraints_model::M2
p::P
constraints_model::M2
du0::D
re::RE
tspan::T
@@ -320,12 +359,21 @@ struct NeuralDAE{P,M,M2,D,RE,T,DV,A,K} <: NeuralDELayer
p = _p
end

new{typeof(p),typeof(model),typeof(constraints_model),
new{typeof(model),typeof(p),typeof(constraints_model),
typeof(du0),typeof(re),typeof(tspan),
typeof(differential_vars),typeof(args),typeof(kwargs)}(
model,constraints_model,p,du0,re,tspan,differential_vars,
model,p,constraints_model,du0,re,tspan,differential_vars,
args,kwargs)
end

function NeuralDAE(model::Lux.Chain,constraints_model,tspan,du0=nothing,args...;p=nothing,differential_vars=nothing,kwargs...)

new{typeof(model),typeof(p),typeof(constraints_model),
typeof(du0),typeof(re),typeof(tspan),
typeof(differential_vars),typeof(args),typeof(kwargs)}(
model,p,constraints_model,du0,re,tspan,differential_vars,
args,kwargs)
end
end

function (n::NeuralDAE)(x,du0=n.du0,p=n.p)
Expand All @@ -348,6 +396,27 @@ function (n::NeuralDAE)(x,du0=n.du0,p=n.p)
solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...)
end

function (n::NeuralDAE{M})(x,p,st,du0=n.du0) where {M<:Lux.AbstractExplicitLayer}
function f(du,u,p,t)
_u = vcat(u,du)
nn_out, st = n.model(_u, p, st)
alg_out = n.constraints_model(u,p,t)
iter_nn = 0
iter_consts = 0
map(n.differential_vars) do isdiff
if isdiff
iter_nn += 1
nn_out[iter_nn]
else
iter_consts += 1
alg_out[iter_consts]
end
end
end
prob = DAEProblem{false}(f,du0,x,n.tspan,p,differential_vars=n.differential_vars)
return solve(prob,n.args...;sensealg=TrackerAdjoint(),n.kwargs...), st
end
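
The residual function above interleaves network outputs and constraint residuals according to `differential_vars`. A tiny standalone sketch of that masking pattern, with made-up values for illustration:

```julia
# Entries flagged as differential come from the network output; the rest come
# from the constraint model, preserving order.
function dae_residual(nn_out, alg_out, differential_vars)
    iter_nn, iter_consts = 0, 0
    map(differential_vars) do isdiff
        if isdiff
            iter_nn += 1
            nn_out[iter_nn]
        else
            iter_consts += 1
            alg_out[iter_consts]
        end
    end
end

dae_residual([1.0, 2.0], [0.5], [true, false, true])  # == [1.0, 0.5, 2.0]
```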

"""
Constructs a physically-constrained continuous-time recurrant neural network,
also known as a neural differential-algebraic equation (neural DAE), with a
7 changes: 4 additions & 3 deletions test/Project.toml
@@ -8,15 +8,16 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
62 changes: 36 additions & 26 deletions test/augmented_nde.jl
@@ -1,50 +1,60 @@
using DiffEqFlux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test
using Lux, DiffEqFlux, DelayDiffEq, OrdinaryDiffEq, StochasticDiffEq, Test, Random

rng = Random.default_rng()
x = Float32[2.; 0.]
xs = Float32.(hcat([0.; 0.], [1.; 0.], [2.; 0.]))
tspan = (0.0f0, 1.0f0)
fastdudt = FastChain(FastDense(4, 50, tanh), FastDense(50, 4))
fastdudt2 = FastChain(FastDense(4, 50, tanh), FastDense(50, 4))
fastdudt22 = FastChain(FastDense(4, 50, tanh), FastDense(50, 16), (x, p) -> reshape(x, 4, 4))
fastddudt = FastChain(FastDense(12, 50, tanh), FastDense(50, 4))
dudt = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4))
dudt2 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 4))
dudt22 = Lux.Chain(Lux.Dense(4, 50, tanh), Lux.Dense(50, 16), ActivationFunction(x -> reshape(x, 4, 4)))
ddudt = Lux.Chain(Lux.Dense(12, 50, tanh), Lux.Dense(50, 4))

# Augmented Neural ODE
anode = AugmentedNDELayer(
NeuralODE(fastdudt, tspan, Tsit5(), save_everystep=false, save_start=false), 2
NeuralODE(dudt, tspan, Tsit5(), save_everystep=false, save_start=false), 2
)
anode(x)

grads = Zygote.gradient(() -> sum(anode(x)), Flux.params(x, anode))
@test ! iszero(grads[x])
@test ! iszero(grads[anode.p])
p1, st1 = Lux.setup(rng, dudt)
p1 = Lux.ComponentArray(p1)
anode(x,p1,st1)
grads = Zygote.gradient((x, p, st) -> sum(anode(x, p, st)[1]), x, p1, st1)
@test ! iszero(grads[1])
@test ! iszero(grads[2])

# Augmented Neural DSDE
andsde = AugmentedNDELayer(
NeuralDSDE(fastdudt, fastdudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1), 2
NeuralDSDE(dudt, dudt2, (0.0f0, 0.1f0), SOSRI(), saveat=0.0:0.01:0.1), 2
)
andsde(x)
p2, st2 = Lux.setup(rng, dudt2)
p2 = Lux.ComponentArray(p2)
p = [p1,p2]
andsde(x,p,st1,st2)

grads = Zygote.gradient(() -> sum(andsde(x)), Flux.params(x, andsde))
@test ! iszero(grads[x])
@test ! iszero(grads[andsde.p])
grads = Zygote.gradient((x,p,st1,st2) -> sum(andsde(x,p,st1,st2)[1]),x,p,st1,st2)
@test ! iszero(grads[1])
@test ! iszero(grads[2])

# Augmented Neural SDE
asode = AugmentedNDELayer(
NeuralSDE(fastdudt, fastdudt22,(0.0f0, 0.1f0), 4, LambaEM(), saveat=0.0:0.01:0.1), 2
NeuralSDE(dudt, dudt22,(0.0f0, 0.1f0), 4, LambaEM(), saveat=0.0:0.01:0.1), 2
)
asode(x)
p22, st22 = Lux.setup(rng,dudt22)
p22 = Lux.ComponentArray(p22)
p = [p1,p22]
asode(x,p,st1,st22)

grads = Zygote.gradient(() -> sum(asode(x)), Flux.params(x, asode))
@test ! iszero(grads[x])
@test ! iszero(grads[asode.p])
ograds = Zygote.gradient((x,p,st1,st22) -> sum(asode(x,p,st1,st22)[1]),x,p,st1,st22)
@test ! iszero(grads[1])
@test ! iszero(grads[1])

# Augmented Neural CDDE
adode = AugmentedNDELayer(
NeuralCDDE(fastddudt, (0.0f0, 2.0f0), (p, t) -> zeros(Float32, 4), (1f-1, 2f-1),
NeuralCDDE(ddudt, (0.0f0, 2.0f0), (p, t) -> zeros(Float32, 4), (1f-1, 2f-1),
MethodOfSteps(Tsit5()), saveat=0.0:0.1:2.0), 2
)
adode(x)
p, st = Lux.setup(rng, ddudt)
p = Lux.ComponentArray(p)
adode(x,p,st)

grads = Zygote.gradient(() -> sum(adode(x)), Flux.params(x, adode))
@test ! iszero(grads[x])
@test ! iszero(grads[adode.p])
grads = Zygote.gradient((x,p,st) -> sum(adode(x,p,st)[1]), x, p, st)
@test ! iszero(grads[1])
@test ! iszero(grads[2])