Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 19, 2023
1 parent 1903cfa commit 4d6c1ed
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 114 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DeepEquilibriumNetworks"
uuid = "6748aba7-0e9b-415e-a410-ae3cc0ecb334"
authors = ["Avik Pal <avikpal@mit.edu>"]
version = "1.3.0"
version = "1.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down Expand Up @@ -36,7 +36,7 @@ NonlinearSolve = "2"
OrdinaryDiffEq = "6"
Reexport = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
SciMLSensitivity = "7.43"
Setfield = "1"
Static = "0.6, 0.7, 0.8"
SteadyStateDiffEq = "1.16"
Expand Down
39 changes: 20 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,31 @@ Pkg.add("DeepEquilibriumNetworks")
## Quickstart

```julia
import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import Zygote
using DeepEquilibriumNetworks, Lux, Random, Zygote
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)

model = Lux.Chain(Lux.Dense(2, 2),
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
Lux.Dense(2, 2; use_bias=false),
Lux.Dense(2, 2; use_bias=false)),
DEQs.ContinuousDEQSolver(;
abstol=0.1f0,
reltol=0.1f0,
abstol_termination=0.1f0,
reltol_termination=0.1f0)))

ps, st = gpu.(Lux.setup(rng, model))
x = gpu(rand(rng, Float32, 2, 1))
y = gpu(rand(rng, Float32, 2, 1))

gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(Parallel(+,
Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
reltol_termination=0.1f0);
save_everystep=true))

gdev = gpu_device()
cdev = cpu_device()

ps, st = Lux.setup(rng, model) |> gdev
x = rand(rng, Float32, 2, 1) |> gdev
y = rand(rng, Float32, 2, 1) |> gdev

model(x, ps, st)

gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
```

## Citation
Expand Down
40 changes: 20 additions & 20 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

DeepEquilibriumNetworks.jl is a framework built on top of
[DifferentialEquations.jl](https://docs.sciml.ai/DiffEqDocs/stable/) and
[Lux.jl](https://docs.sciml.ai/Lux/stable/), enabling the efficient training and inference for
[Lux.jl](https://lux.csail.mit.edu/), enabling the efficient training and inference for
Deep Equilibrium Networks (Infinitely Deep Neural Networks).

## Installation
Expand All @@ -17,30 +17,30 @@ Pkg.add("DeepEquilibriumNetworks")
## Quick-start

```julia
import DeepEquilibriumNetworks as DEQs
import Lux
import Random
import Zygote
using DeepEquilibriumNetworks, Lux, Random, Zygote
# using LuxCUDA, LuxAMDGPU ## Install and Load for GPU Support

seed = 0
rng = Random.default_rng()
Random.seed!(rng, seed)
model = Chain(Dense(2 => 2),
DeepEquilibriumNetwork(Parallel(+,
Dense(2 => 2; use_bias=false),
Dense(2 => 2; use_bias=false)),
ContinuousDEQSolver(; abstol=0.1f0, reltol=0.1f0, abstol_termination=0.1f0,
reltol_termination=0.1f0);
save_everystep=true))

model = Lux.Chain(Lux.Dense(2, 2),
DEQs.DeepEquilibriumNetwork(Lux.Parallel(+,
Lux.Dense(2, 2; use_bias=false),
Lux.Dense(2, 2; use_bias=false)),
DEQs.ContinuousDEQSolver(;
abstol=0.1f0,
reltol=0.1f0,
abstol_termination=0.1f0,
reltol_termination=0.1f0)))

ps, st = gpu.(Lux.setup(rng, model))
x = gpu(rand(rng, Float32, 2, 1))
y = gpu(rand(rng, Float32, 2, 1))

gs = Zygote.gradient(p -> sum(abs2, model(x, p, st)[1][1] .- y), ps)[1]
gdev = gpu_device()
cdev = cpu_device()

ps, st = Lux.setup(rng, model) |> gdev
x = rand(rng, Float32, 2, 1) |> gdev
y = rand(rng, Float32, 2, 1) |> gdev

model(x, ps, st)

gs = only(Zygote.gradient(p -> sum(abs2, first(first(model(x, p, st))) .- y), ps))
```

## Citation
Expand Down
8 changes: 8 additions & 0 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import ChainRulesCore as CRC
import ConcreteStructs: @concrete

const DEQs = DeepEquilibriumNetworks
const ∂∅ = CRC.NoTangent()

## FIXME: Uses of nothing was removed in Lux 0.5 with a deprecation. It was not updated
## here
Expand All @@ -33,6 +34,13 @@ include("layers/evaluate.jl")

include("chainrules.jl")

# Start of Weird Patches
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
# needed.
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))

Check warning on line 41 in src/DeepEquilibriumNetworks.jl

View check run for this annotation

Codecov / codecov/patch

src/DeepEquilibriumNetworks.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
# End of Weird Patches

# Useful Shorthand
export DEQs

Expand Down
32 changes: 17 additions & 15 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
__backing(Δ::CRC.Tangent) = __backing(CRC.backing(Δ))
__backing(Δ::Tuple) = __backing.(Δ)
__backing(Δ::NamedTuple{F}) where {F} = NamedTuple{F}(__backing(values(Δ)))
__backing(Δ) = Δ

Check warning on line 4 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L1-L4

Added lines #L1 - L4 were not covered by tests

function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,

Check warning on line 6 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L6

Added line #L6 was not covered by tests
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
function deep_equilibrium_solution_pullback(dsol)
return (CRC.NoTangent(), dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss,
dsol.nfe)
return (∂∅, dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss, dsol.nfe)

Check warning on line 9 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L9

Added line #L9 was not covered by tests
end
return (DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe),
deep_equilibrium_solution_pullback)
end

function _safe_getfield(x::NamedTuple{fields}, field) where {fields}
return field ∈ fields ? getfield(x, field) : CRC.NoTangent()
return field ∈ fields ? getfield(x, field) : ∂∅

Check warning on line 16 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L16

Added line #L16 was not covered by tests
end

function CRC.rrule(::Type{T}, args...) where {T <: NamedTuple}
y = T(args...)
function nt_pullback(dy)
fields = fieldnames(T)
if dy isa CRC.Tangent
dy = CRC.backing(dy)
end
return (CRC.NoTangent(), _safe_getfield.((dy,), fields)...)
dy isa CRC.Tangent && (dy = CRC.backing(dy))
return (∂∅, _safe_getfield.((dy,), fields)...)

Check warning on line 24 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
end
return y, nt_pullback
end
Expand All @@ -28,20 +30,20 @@ function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
val) where {field}
res = Setfield.set(obj, l, val)
function setfield_pullback(Δres)
if Δres isa CRC.Tangent
Δres = CRC.backing(Δres)
end
Δobj = Setfield.set(obj, l, CRC.NoTangent())
return (CRC.NoTangent(), Δobj, CRC.NoTangent(), getfield(Δres, field))
Δres = __backing(Δres)
Δobj = Setfield.set(obj, l, ∂∅)
return (∂∅, Δobj, ∂∅, getfield(Δres, field))

Check warning on line 35 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L33-L35

Added lines #L33 - L35 were not covered by tests
end
return res, setfield_pullback
end

function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z, ps, x)
function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z,

Check warning on line 40 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L40

Added line #L40 was not covered by tests
ps::NamedTuple{F}, x) where {F}
prob = _construct_problem(deq, dudt, z, ps, x)
function ∇_construct_problem(Δ)
return (CRC.NoTangent(), CRC.NoTangent(), CRC.NoTangent(), Δ.u0,
(; model = Δ.p.ps), Δ.p.x)
Δ = __backing(Δ)
nograds = NamedTuple{F}(ntuple(i -> ∂∅, length(F)))
return (∂∅, ∂∅, ∂∅, Δ.u0, merge(nograds, (; model=Δ.p.ps)), Δ.p.x)

Check warning on line 46 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L42-L46

Added lines #L42 - L46 were not covered by tests
end
return prob, ∇_construct_problem

Check warning on line 48 in src/chainrules.jl

View check run for this annotation

Codecov / codecov/patch

src/chainrules.jl#L48

Added line #L48 was not covered by tests
end
7 changes: 4 additions & 3 deletions src/layers/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ end
@inline _postprocess_output(_, z_star) = z_star

@inline function _construct_problem(::AbstractDEQs, dudt, z, ps, x)
return SteadyStateProblem(ODEFunction{false}(dudt), z,
NamedTuple{(:ps, :x)}((ps.model, x)))
return SteadyStateProblem(ODEFunction{false}(dudt), z, (; ps=ps.model, x))

Check warning on line 17 in src/layers/evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/evaluate.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
end

@inline _fix_solution_output(_, x) = x
Expand Down Expand Up @@ -48,7 +47,9 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})

prob = _construct_problem(deq, dudt, z, ps, x)

Check warning on line 48 in src/layers/evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/evaluate.jl#L48

Added line #L48 was not covered by tests
sol = solve(prob, deq.solver; deq.sensealg, deq.kwargs...)
z_star = sol.u
_z_star = sol.u

Check warning on line 50 in src/layers/evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/evaluate.jl#L50

Added line #L50 was not covered by tests
# Handle Neural ODEs
z_star = _z_star isa Vector{<:AbstractArray} ? last(_z_star) : _z_star

Check warning on line 52 in src/layers/evaluate.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/evaluate.jl#L52

Added line #L52 was not covered by tests

if _jacobian_regularization(deq)
rng = Lux.replicate(st.rng)
Expand Down
7 changes: 3 additions & 4 deletions src/layers/mdeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ end
"""
MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,
post_fuse_layer::Union{Nothing,Tuple}, solver, scales;
sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()),
kwargs...)
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...)
Multiscale Neural ODE with Input Injection.
Expand Down Expand Up @@ -334,7 +333,7 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref)
"""
function MultiScaleNeuralODE(main_layers::Tuple, mapping_layers::Matrix,

Check warning on line 334 in src/layers/mdeq.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/mdeq.jl#L334

Added line #L334 was not covered by tests
post_fuse_layer::Union{Nothing, Tuple}, solver, scales::NTuple{N, NTuple{L, Int64}};
sensealg=InterpolatingAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
sensealg=GaussAdjoint(; autojacvec=ZygoteVJP()), kwargs...) where {N, L}
l1 = Parallel(nothing, main_layers...)
l2 = BranchLayer(Parallel.(+, map(x -> tuple(x...), eachrow(mapping_layers))...)...)

Expand All @@ -357,7 +356,7 @@ function _get_initial_condition(deq::MultiScaleNeuralODE, x, ps, st)
end

@inline function _construct_problem(::MultiScaleNeuralODE, dudt, z, ps, x)
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), ps.model)
return ODEProblem(ODEFunction{false}(dudt), z, (0.0f0, 1.0f0), (; ps=ps.model, x))

Check warning on line 359 in src/layers/mdeq.jl

View check run for this annotation

Codecov / codecov/patch

src/layers/mdeq.jl#L358-L359

Added lines #L358 - L359 were not covered by tests
end

@inline _fix_solution_output(::MultiScaleNeuralODE, x) = x[end]
Expand Down
Loading

0 comments on commit 4d6c1ed

Please sign in to comment.