Deep Equilibrium Models
(Bai et al., 2019) introduced Discrete Deep Equilibrium Models, which drive a discrete dynamical system to its steady state. (Pal et al., 2022) extends this framework to continuous dynamical systems, which converge to the steady state in a more stable fashion. For a detailed discussion, refer to (Pal et al., 2022).
To construct a continuous DEQ, pass any ODE solver compatible with the DifferentialEquations.jl API as the solver. To construct a discrete DEQ, pass any root-finding algorithm compatible with the NonlinearSolve.jl API as the solver.
Choosing a Solver
Root Finding Algorithms
Root-finding algorithms give fast convergence when possible, but they also tend to be unstable. If you must use a root-finding algorithm, we recommend:

- `NewtonRaphson` or `TrustRegion` for small models
- `LimitedMemoryBroyden` for large deep learning applications (with well-conditioned Jacobians)
- `NewtonRaphson(; linsolve = KrylovJL_GMRES())` for cases where Broyden methods fail
Note that Krylov Methods rely on efficient VJPs which are not available for all Lux models. If you think this is causing a performance regression, please open an issue in Lux.jl.
ODE Solvers
ODE solvers give slower convergence, but are more stable. We generally recommend these methods over root-finding algorithms. If you use implicit ODE solvers, remember to use Krylov linear solvers; see the OrdinaryDiffEq.jl documentation for details. For most cases, we recommend:

- `VCAB3()` for high tolerance problems
- `Tsit5()` for high tolerance problems where `VCAB3()` fails
- In all other cases, follow the recommendations in the OrdinaryDiffEq.jl documentation
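For example, the same network can be driven by either solver family simply by swapping the `solver` argument. The following is a minimal sketch, assuming `DeepEquilibriumNetworks`, `Lux`, `OrdinaryDiffEq`, and `NonlinearSolve` are installed and loaded:

```julia
using DeepEquilibriumNetworks, Lux, OrdinaryDiffEq, NonlinearSolve, Random

core = Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false))

# Continuous DEQ: an ODE solver integrates the dynamics to steady state
continuous_deq = DeepEquilibriumNetwork(core, VCAB3(); verbose=false)

# Discrete DEQ: a root finder solves the fixed-point equation directly
discrete_deq = DeepEquilibriumNetwork(core, NewtonRaphson(); verbose=false)

rng = Xoshiro(0)
ps, st = Lux.setup(rng, continuous_deq)
z, _ = continuous_deq(ones(Float32, 2, 1), ps, st)
```

Both models share the same forward-pass interface; only the mechanism used to reach the fixed point differs.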
Sensitivity Analysis
- For `MultiScaleNeuralODE`, we default to `GaussAdjoint(; autojacvec = ZygoteVJP())`. A faster alternative is `BacksolveAdjoint(; autojacvec = ZygoteVJP())`, but it comes with stability concerns; follow the recommendations in the SciMLSensitivity.jl documentation.
- For steady-state problems, we default to `SteadyStateAdjoint(; linsolve = SimpleGMRES(; blocksize, linsolve_kwargs = (; maxiters=10, abstol=1e-3, reltol=1e-3)))`. This default performs poorly on small models, so for those it is recommended to pass `sensealg = SteadyStateAdjoint()` or `sensealg = SteadyStateAdjoint(; linsolve = LUFactorization())` instead.
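Since additional keyword arguments to the network constructors are forwarded to `SciMLBase.solve`, the sensitivity algorithm can be overridden at construction time. A sketch for a small model, assuming the relevant packages are loaded:

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, SciMLSensitivity, LinearSolve, Random

model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    NewtonRaphson();
    # For small models a dense LU factorization is cheaper than the
    # default GMRES-based linear solve inside the adjoint
    sensealg = SteadyStateAdjoint(; linsolve = LUFactorization()))

rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
z, st = model(ones(Float32, 2, 1), ps, st)
```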
Standard Models
DeepEquilibriumNetworks.DeepEquilibriumNetwork — Type

```julia
DeepEquilibriumNetwork(model, solver; init = missing, jacobian_regularization = nothing,
    problem_type::Type{pType} = SteadyStateProblem{false}, kwargs...)
```
Deep Equilibrium Network as proposed in (Bai et al., 2019) and (Pal et al., 2022).
Arguments
- `model`: Neural network.
- `solver`: Solver for the root-finding problem. ODE solvers and nonlinear solvers are both supported.

Keyword Arguments

- `init`: Initial condition for the root-finding problem. If `nothing`, the initial condition is set to `zero(x)`. If `missing`, the initial condition is set to `WrappedFunction{:direct_call}(zero)`. In other cases the initial condition is set to `init(x, ps, st)`.
- `jacobian_regularization`: Must be one of `nothing`, `AutoForwardDiff`, `AutoFiniteDiff`, or `AutoZygote`.
- `problem_type`: Provides a way to simulate a vanilla Neural ODE by setting the `problem_type` to `ODEProblem`. By default, the problem type is set to `SteadyStateProblem`.
- `kwargs`: Additional parameters that are passed directly to `SciMLBase.solve`.
Example
```julia
julia> model = DeepEquilibriumNetwork(
           Parallel(+, Dense(2, 2; use_bias=false), Dense(2, 2; use_bias=false)),
           VCABM3(); verbose=false)
DeepEquilibriumNetwork(
    model = Parallel(
        +
        layer_1 = Dense(2 => 2, bias=false),  # 4 parameters
        layer_2 = Dense(2 => 2, bias=false),  # 4 parameters
    ),
    init = WrappedFunction(Base.Fix1{typeof(DeepEquilibriumNetworks.__zeros_init), Nothing}(DeepEquilibriumNetworks.__zeros_init, nothing)),
)         # Total: 8 parameters,
          #        plus 0 states.

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> size(first(model(ones(Float32, 2, 1), ps, st)))
(2, 1)
```
See also: `SkipDeepEquilibriumNetwork`, `MultiScaleDeepEquilibriumNetwork`, `MultiScaleSkipDeepEquilibriumNetwork`.
DeepEquilibriumNetworks.SkipDeepEquilibriumNetwork — Function

```julia
SkipDeepEquilibriumNetwork(model, [init=nothing,] solver; kwargs...)
```

Skip Deep Equilibrium Network as proposed in (Pal et al., 2022). Alias that creates a `DeepEquilibriumNetwork` with the `init` kwarg set to the passed value.
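A sketch of both forms of the alias, assuming `DeepEquilibriumNetworks`, `Lux`, and `NonlinearSolve` are loaded:

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

core = Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false))

# Explicit init network: predicts the initial condition from the input
skip_deq = SkipDeepEquilibriumNetwork(core, Dense(2 => 2), NewtonRaphson())

# No init network passed: init stays `nothing`
skip_reg_deq = SkipDeepEquilibriumNetwork(core, NewtonRaphson())

rng = Xoshiro(0)
ps, st = Lux.setup(rng, skip_deq)
z, _ = skip_deq(ones(Float32, 2, 1), ps, st)
```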
MultiScale Models
DeepEquilibriumNetworks.MultiScaleDeepEquilibriumNetwork — Function

```julia
MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)
```
Multi Scale Deep Equilibrium Network as proposed in (Bai et al., 2020).
Arguments
- `main_layers`: Tuple of neural networks. Each network is applied to the corresponding scale.
- `mapping_layers`: Matrix of neural networks. Each network is applied to the corresponding scale and the corresponding layer.
- `post_fuse_layer`: Neural network applied to the fused output of the main layers.
- `solver`: Solver for the root-finding problem. ODE solvers and nonlinear solvers are both supported.
- `scales`: Scales of the multi-scale DEQ. Each scale is a tuple of integers; the length of the tuple is the number of layers in the corresponding main layer.
For keyword arguments, see `DeepEquilibriumNetwork`.
Example
```julia
julia> main_layers = (
           Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
           Dense(3 => 3, tanh), Dense(2 => 2, tanh), Dense(1 => 1, tanh));

julia> mapping_layers = [NoOpLayer() Dense(4 => 3, tanh) Dense(4 => 2, tanh) Dense(4 => 1, tanh);
                         Dense(3 => 4, tanh) NoOpLayer() Dense(3 => 2, tanh) Dense(3 => 1, tanh);
                         Dense(2 => 4, tanh) Dense(2 => 3, tanh) NoOpLayer() Dense(2 => 1, tanh);
                         Dense(1 => 4, tanh) Dense(1 => 3, tanh) Dense(1 => 2, tanh) NoOpLayer()];

julia> model = MultiScaleDeepEquilibriumNetwork(
           main_layers, mapping_layers, nothing, NewtonRaphson(), ((4,), (3,), (2,), (1,)));

julia> rng = Xoshiro(0);

julia> ps, st = Lux.setup(rng, model);

julia> x = rand(rng, Float32, 4, 12);

julia> size.(first(model(x, ps, st)))
((4, 12), (3, 12), (2, 12), (1, 12))
```
DeepEquilibriumNetworks.MultiScaleSkipDeepEquilibriumNetwork — Function

```julia
MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Matrix,
    post_fuse_layer::Union{Nothing, Tuple}, [init = nothing,] solver,
    scales::NTuple{N, NTuple{L, Int64}}; kwargs...)
```
Skip Multi-Scale Deep Equilibrium Network as proposed in (Pal et al., 2022). Alias that creates a `MultiScaleDeepEquilibriumNetwork` with the `init` kwarg set to the passed value.

If `init` is not passed, it creates a multi-scale regularized Deep Equilibrium Network.
DeepEquilibriumNetworks.MultiScaleNeuralODE — Function

```julia
MultiScaleNeuralODE(args...; kwargs...)
```

Same arguments as `MultiScaleDeepEquilibriumNetwork`, but sets `problem_type` to `ODEProblem{false}`.
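A two-scale sketch mirroring the `MultiScaleDeepEquilibriumNetwork` example above, but with an ODE solver since the dynamics are integrated as an ODE rather than driven to a steady state (assuming the relevant packages are loaded):

```julia
using DeepEquilibriumNetworks, Lux, OrdinaryDiffEq, Random

main_layers = (
    Parallel(+, Dense(4 => 4, tanh; use_bias=false), Dense(4 => 4, tanh; use_bias=false)),
    Dense(2 => 2, tanh))

mapping_layers = [NoOpLayer() Dense(4 => 2, tanh);
                  Dense(2 => 4, tanh) NoOpLayer()]

# Identical call signature to MultiScaleDeepEquilibriumNetwork
model = MultiScaleNeuralODE(main_layers, mapping_layers, nothing, VCAB3(), ((4,), (2,)))

rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 12)
sizes = size.(first(model(x, ps, st)))
```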
Solution
DeepEquilibriumNetworks.DeepEquilibriumSolution — Type

```julia
DeepEquilibriumSolution(z_star, u₀, residual, jacobian_loss, nfe, solution)
```
Stores the solution of a DeepEquilibriumNetwork and its variants.
Fields
- `z_star`: Steady state, or the value reached due to `maxiters`
- `u0`: Initial condition
- `residual`: Difference of the $z^*$ and $f(z^*, x)$
- `jacobian_loss`: Jacobian stabilization loss (see individual networks to see how it can be computed)
- `nfe`: Number of function evaluations
- `original`: Original internal solution
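As a sketch of inspecting these fields after a forward pass: here we assume the `DeepEquilibriumSolution` is available as `st.solution` in the updated layer state (the `solution` field name is an assumption; check the state returned by your version if it differs):

```julia
using DeepEquilibriumNetworks, Lux, NonlinearSolve, Random

model = DeepEquilibriumNetwork(
    Parallel(+, Dense(2 => 2; use_bias=false), Dense(2 => 2; use_bias=false)),
    NewtonRaphson())

rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
z, st = model(ones(Float32, 2, 1), ps, st)

sol = st.solution   # assumed location of the DeepEquilibriumSolution
sol.nfe             # number of function evaluations used by the solver
sol.residual        # difference between z* and f(z*, x)
```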