Skip to content

Commit

Permalink
Remove Unwanted Dependencies and rrules
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 20, 2023
1 parent 1e414cd commit 66ecf1b
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 75 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ jobs:
- ADJOINT
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
8 changes: 2 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -22,15 +21,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRulesCore = "1"
ConcreteStructs = "0.2"
DiffEqBase = "6.119"
LinearSolve = "1, 2"
Lux = "0.4, 0.5"
MLUtils = "0.2, 0.3, 0.4"
Lux = "0.5.7"
NonlinearSolve = "2"
OrdinaryDiffEq = "6"
Reexport = "1"
Expand All @@ -40,5 +37,4 @@ Setfield = "1"
SteadyStateDiffEq = "1.16"
TruncatedStacktraces = "1.1"
Zygote = "0.6.34"
ZygoteRules = "0.2"
julia = "1.6"
julia = "1.9"
12 changes: 2 additions & 10 deletions src/DeepEquilibriumNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import Reexport: @reexport

@reexport using Lux, NonlinearSolve, OrdinaryDiffEq, SciMLSensitivity

using DiffEqBase, LinearAlgebra, LinearSolve, MLUtils, Random, SciMLBase,
Setfield, Statistics, SteadyStateDiffEq, Zygote
using DiffEqBase, LinearAlgebra, LinearSolve, Random, SciMLBase, Statistics,
SteadyStateDiffEq, Zygote

import DiffEqBase: AbstractSteadyStateProblem
import SciMLBase: AbstractNonlinearSolution, AbstractSteadyStateAlgorithm
Expand Down Expand Up @@ -34,14 +34,6 @@ include("layers/evaluate.jl")

include("chainrules.jl")

# Start of Weird Patches
# Honestly no clue why this is needed! -- probably a whacky fix which shouldn't be ever
# needed.
using ZygoteRules
ZygoteRules.gradtuple1(::NamedTuple{()}) = (nothing, nothing, nothing, nothing, nothing)
ZygoteRules.gradtuple1(x::NamedTuple) = collect(values(x))
# End of Weird Patches

# Useful Shorthand
export DEQs

Expand Down
45 changes: 2 additions & 43 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,8 @@
# `__backing` recursively strips `CRC.Tangent` wrappers from a cotangent so the result
# is made of plain tuples / named tuples / leaf values.

# Structural tangent: take its backing and keep unwrapping whatever it contained.
__backing(Δ::CRC.Tangent) = __backing(CRC.backing(Δ))
# Tuple: unwrap every element.
__backing(Δ::Tuple) = __backing.(Δ)
# NamedTuple: unwrap every value and rebuild with the same field names `F`.
__backing(Δ::NamedTuple{F}) where {F} = NamedTuple{F}(__backing(values(Δ)))
# Fallback: anything else is returned unchanged.
__backing(Δ) = Δ

function CRC.rrule(::Type{<:DeepEquilibriumSolution}, z_star::T, u0::T, residual::T,
jacobian_loss::R, nfe::Int) where {T, R <: AbstractFloat}
function deep_equilibrium_solution_pullback(dsol)
function ∇deep_equilibrium_solution(dsol)
return (∂∅, dsol.z_star, dsol.u0, dsol.residual, dsol.jacobian_loss, dsol.nfe)
end
return (DeepEquilibriumSolution(z_star, u0, residual, jacobian_loss, nfe),
deep_equilibrium_solution_pullback)
end

# Return `getfield(x, field)` when `field` is one of the field names of the named tuple
# `x`; otherwise return `∂∅` instead of throwing.
function _safe_getfield(x::NamedTuple{fields}, field) where {fields}
    # `fields` is the tuple of field-name symbols taken from the type parameter.
    return field ∈ fields ? getfield(x, field) : ∂∅
end

# Reverse rule for calling a NamedTuple type `T` as a constructor on `args...`.
#
# The pullback returns `∂∅` for the type itself, followed by one entry per field name
# of `T`, looked up in the incoming cotangent with `_safe_getfield` (so fields absent
# from the cotangent yield `∂∅`).
#
# NOTE(review): this adds a method for every `T <: NamedTuple` to `CRC.rrule`; neither
# the function nor the type appears to be owned by this package, which looks like type
# piracy — confirm it is really needed.
function CRC.rrule(::Type{T}, args...) where {T <: NamedTuple}
    function nt_pullback(dy)
        names = fieldnames(T)
        # A `CRC.Tangent` is replaced by its backing before the field lookup.
        Δ = dy isa CRC.Tangent ? CRC.backing(dy) : dy
        return (∂∅, map(name -> _safe_getfield(Δ, name), names)...)
    end
    return T(args...), nt_pullback
end

# Reverse rule for `Setfield.set(obj, l, val)` where `l` is a `PropertyLens{field}`.
#
# The pullback returns, in order:
#   * `∂∅` for the function,
#   * for `obj`: a copy of `obj` with the lensed property replaced by `∂∅`,
#   * `∂∅` for the lens,
#   * for `val`: the `field` entry of the (unwrapped) incoming cotangent.
function CRC.rrule(::typeof(Setfield.set), obj, l::Setfield.PropertyLens{field},
    val) where {field}
    function setfield_pullback(Δres)
        # Strip any `CRC.Tangent` wrappers so `getfield` works on the cotangent.
        Δ = __backing(Δres)
        return (∂∅, Setfield.set(obj, l, ∂∅), ∂∅, getfield(Δ, field))
    end
    return Setfield.set(obj, l, val), setfield_pullback
end

function CRC.rrule(::typeof(_construct_problem), deq::AbstractDEQs, dudt, z,
ps::NamedTuple{F}, x) where {F}
prob = _construct_problem(deq, dudt, z, ps, x)
function ∇_construct_problem(Δ)
Δ = __backing(Δ)
nograds = NamedTuple{F}(ntuple(i -> ∂∅, length(F)))
return (∂∅, ∂∅, ∂∅, Δ.u0, merge(nograds, (; model=Δ.p.ps)), Δ.p.x)
end
return prob, ∇_construct_problem
∇deep_equilibrium_solution)
end
6 changes: 2 additions & 4 deletions src/layers/deq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,10 @@ _jacobian_regularization(::SkipDeepEquilibriumNetwork{J}) where {J} = J
# Initial condition for a `SkipDeepEquilibriumNetwork` whose third type parameter is
# `Nothing` (presumably: no separate shortcut layer — confirm against the struct).
# Runs `deq.model` on `(zero(x), x)` and uses its output as `z`.
#
# Returns `(z, st)`, where the returned `st` has its `model` entry replaced by the state
# that `deq.model` returned.
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork{J, M, Nothing}, x, ps,
    st) where {J, M}
    z, st_model = deq.model((zero(x), x), ps.model, st.model)
    # Single non-mutating update; assumes `st` is a NamedTuple — TODO confirm.
    return z, merge(st, (; model=st_model))
end

# Initial condition for a `SkipDeepEquilibriumNetwork`: the output of `deq.shortcut`
# applied to `x` is used as `z`.
#
# Returns `(z, st)`, where the returned `st` has its `shortcut` entry replaced by the
# state that `deq.shortcut` returned.
function _get_initial_condition(deq::SkipDeepEquilibriumNetwork, x, ps, st)
    z, st_shortcut = deq.shortcut(x, ps.shortcut, st.shortcut)
    # Single non-mutating update; assumes `st` is a NamedTuple — TODO confirm.
    return z, merge(st, (; shortcut=st_shortcut))
end
14 changes: 7 additions & 7 deletions src/layers/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ function (deq::AbstractDEQs)(x::AbstractArray{T}, ps, st::NamedTuple, ::Val{true
z_star, st_ = _evaluate_unrolled_model(deq, deq.model, z, x, ps.model, st.model,
st.fixed_depth)

@set! st.model = st_
@set! st.solution = build_solution(deq, z_star, z, x, ps, st, depth, T(0))
st__ = merge(st,
(; model=st_, solution=build_solution(deq, z_star, z, x, ps, st, depth, T(0))))

return _postprocess_output(deq, z_star), st
return _postprocess_output(deq, z_star), st__
end

function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
Expand Down Expand Up @@ -60,9 +60,9 @@ function (deq::AbstractDEQs)(x::AbstractArray, ps, st::NamedTuple, ::Val{false})
jac_loss = T(0)
end

@set! st.model = model.st
@set! st.solution = build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)
@set! st.rng = rng
st_ = merge(st,
(; model=model.st, rng,
solution=build_solution(deq, z_star, z, x, ps, st, nfe, jac_loss)))

return _postprocess_output(deq, z_star), st
return _postprocess_output(deq, z_star), st_
end
6 changes: 2 additions & 4 deletions src/layers/mdeq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,13 @@ function _get_initial_condition(deq::MultiScaleSkipDeepEquilibriumNetwork{N, M,
x, ps, st) where {N, M}
u0, st = _get_zeros_initial_condition_mdeq(deq.scales, x, st)
z, st_ = deq.model((u0, x), ps.model, st.model)
@set! st.model = st_
return z, st
return z, merge(st, (; model=st_))
end

# Initial condition for a `MultiScaleSkipDeepEquilibriumNetwork`: `deq.shortcut` is
# applied to `x`, each element of its output is passed through `flatten`, and the
# results are concatenated with `vcat` to form `z`.
#
# Returns `(z, st)`, where the returned `st` has its `shortcut` entry replaced by the
# state that `deq.shortcut` returned.
function _get_initial_condition(deq::MultiScaleSkipDeepEquilibriumNetwork, x, ps, st)
    z0, st_shortcut = deq.shortcut(x, ps.shortcut, st.shortcut)
    z = mapreduce(flatten, vcat, z0)
    # Single non-mutating update; assumes `st` is a NamedTuple — TODO confirm.
    return z, merge(st, (; shortcut=st_shortcut))
end

@concrete struct MultiScaleNeuralODE{N} <: AbstractDeepEquilibriumNetwork
Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ DEQs.split_and_reshape(x, split_idxs, shapes)
push!(calls, :(return tuple($(varnames...))))
return Expr(:block, calls...)
end

# `flatten` reshapes an array into a matrix whose second dimension is the array's last
# dimension; all leading dimensions are collapsed into the first. `reshape` shares data
# with its argument, so no copy is made.

# Vector: becomes a single-column matrix of size `(length(x), 1)`.
@inline function flatten(x::AbstractVector)
    return reshape(x, length(x), 1)
end

# Matrix: already in the target layout, returned as-is.
@inline function flatten(x::AbstractMatrix)
    return x
end

# Any other array: keep the last dimension, collapse everything before it.
@inline function flatten(x::AbstractArray{T, N}) where {T, N}
    return reshape(x, :, size(x, N))
end

0 comments on commit 66ecf1b

Please sign in to comment.