
Lux compatible layers #750

Merged Jan 11, 2023 · 27 commits

Changes from 25 commits
869ac50
Lux compatible layers
ba2tro Jul 11, 2022
ce5b42c
Merge branch 'SciML:master' into lux
ba2tro Jul 31, 2022
246fa84
Merge branch 'SciML:master' into lux
ba2tro Aug 13, 2022
6b63b73
Merge branch 'SciML:master' into lux
ba2tro Sep 28, 2022
040f5d3
Merge branch 'SciML:master' into lux
ba2tro Oct 11, 2022
b86953b
further changes
ba2tro Oct 11, 2022
1470167
tests
ba2tro Oct 12, 2022
95ddf8c
add lux as a test dep
ba2tro Oct 12, 2022
d540045
fix
ba2tro Oct 12, 2022
7c354de
Merge branch 'SciML:master' into lux
ba2tro Oct 23, 2022
61276bf
solver change
ba2tro Oct 23, 2022
d3b48a9
EulerHeun
ba2tro Oct 23, 2022
30abf27
Update neural_de_lux.jl
ba2tro Oct 23, 2022
c87be78
Merge branch 'SciML:master' into lux
ba2tro Jan 4, 2023
23f05fc
test updates
ba2tro Jan 4, 2023
ba9a948
fix
ba2tro Jan 4, 2023
f28906a
fix AugmentedNDELayer and test with Lux
ba2tro Jan 5, 2023
30deebc
Remove using Revise from there : )
ba2tro Jan 5, 2023
7e5955f
Make AugmentedNDELayer a Lux layer
ba2tro Jan 5, 2023
efb435f
Add functor definitions to restrict parameter list
ChrisRackauckas Jan 6, 2023
de0cfbb
Merge branch 'SciML:master' into lux
ba2tro Jan 9, 2023
7512d32
Consistent naming scheme
ba2tro Jan 9, 2023
d0daeb3
Use outer constructors for NeuralODE
ba2tro Jan 10, 2023
ffd4a8f
Update neural_de.jl
ba2tro Jan 10, 2023
ee9178c
Make all nde layers use outer constructors
ba2tro Jan 10, 2023
b9289d6
Update test/runtests.jl
ChrisRackauckas Jan 10, 2023
34ea536
Call NeuralSDE functions drift and diffusion
ba2tro Jan 11, 2023
1 change: 1 addition & 0 deletions Project.toml
@@ -15,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
1 change: 1 addition & 0 deletions src/DiffEqFlux.jl
@@ -18,6 +18,7 @@ using Requires
 using Cassette
 @reexport using Flux
 @reexport using OptimizationOptimJL
+using Functors

 import ChainRulesCore

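The new Functors dependency relates to the "Add functor definitions to restrict parameter list" commit: `Functors.@functor` can declare which struct fields are traversed as parameters, so solver metadata is not swept up with the trainable weights. A standalone sketch (the `MyLayer` type below is hypothetical, for illustration only):

```julia
# Illustrative sketch: @functor can restrict which fields count as parameters.
using Functors

struct MyLayer{P, T}
    p::P      # trainable parameters
    tspan::T  # fixed metadata that should not be collected or updated
end

# Only `p` participates in fmap / parameter collection; `tspan` is left alone.
Functors.@functor MyLayer (p,)

l = MyLayer([1.0, 2.0], (0.0, 1.0))
l2 = fmap(x -> 2x, l)  # doubles `p`, leaves `tspan` untouched
```

Restricting the functored fields this way keeps optimizers and `fmap`-based transformations from touching non-parameter fields such as time spans or solver arguments.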
4 changes: 2 additions & 2 deletions src/hnn.jl
@@ -90,7 +90,7 @@ Arguments:
 documentation for more details.
 """
 struct NeuralHamiltonianDE{M,P,RE,T,A,K} <: NeuralDELayer
-    hnn::HamiltonianNN{M,RE,P}
+    model::HamiltonianNN{M,RE,P}
     p::P
     tspan::T
     args::A
@@ -112,7 +112,7 @@ end

 function (nhde::NeuralHamiltonianDE)(x, p = nhde.p)
     function neural_hamiltonian!(du, u, p, t)
-        du .= reshape(nhde.hnn(u, p), size(du))
+        du .= reshape(nhde.model(u, p), size(du))
     end
     prob = ODEProblem(neural_hamiltonian!, x, nhde.tspan, p)
     # NOTE: Nesting Zygote is an issue. So we can't use ZygoteVJP
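With this rename, the Hamiltonian layer's network lives in a field called `model`, matching the naming scheme of the other neural-DE layers (see the "Consistent naming scheme" commit). A hedged usage sketch (constructor arguments are illustrative and assume the DiffEqFlux API of this era, not code from this PR):

```julia
# Hedged sketch: constructor arguments are illustrative, not taken from this PR.
using DiffEqFlux, OrdinaryDiffEq, Flux

hnn = HamiltonianNN(Flux.Chain(Flux.Dense(2, 16, tanh), Flux.Dense(16, 1)))
nhde = NeuralHamiltonianDE(hnn, (0.0f0, 1.0f0), Tsit5(); save_everystep = false)

# After this PR, the wrapped network is reached via `model`, not `hnn`:
nhde.model  # the HamiltonianNN
```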