From 1903cfa8b9eca06c604bd1e5d18c6a3cc9268c26 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Sep 2023 17:09:36 -0400 Subject: [PATCH] Allow debug mode --- src/layers/mdeq.jl | 15 ++++++++++++--- src/solve.jl | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/layers/mdeq.jl b/src/layers/mdeq.jl index f1a6cc2e..312a6123 100644 --- a/src/layers/mdeq.jl +++ b/src/layers/mdeq.jl @@ -17,7 +17,7 @@ end u, x = z u_ = split_and_reshape(u, m.split_idxs, m.scales) u_res, st = m.model(($(inputs...),), ps, st) - return vcat(flatten.(u_res)...), st + return mapreduce(flatten, vcat, u_res), st end end @@ -80,6 +80,10 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref) kwargs end +function MultiScaleDeepEquilibriumNetwork(model::MultiScaleInputLayer{N}, args...) where {N} + return MultiScaleDeepEquilibriumNetwork{N}(model, args...) +end + @truncate_stacktrace MultiScaleDeepEquilibriumNetwork 1 3 function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleDeepEquilibriumNetwork) @@ -104,7 +108,7 @@ function MultiScaleDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers::Ma split_idxs, scales) end - return MultiScaleDeepEquilibriumNetwork{N}(model, solver, sensealg, scales, split_idxs, + return MultiScaleDeepEquilibriumNetwork(model, solver, sensealg, scales, split_idxs, kwargs) end @@ -205,6 +209,11 @@ See also: [`DeepEquilibriumNetwork`](@ref), [`SkipDeepEquilibriumNetwork`](@ref) kwargs end +function MultiScaleSkipDeepEquilibriumNetwork(model::MultiScaleInputLayer{N}, + args...) where {N} + return MultiScaleSkipDeepEquilibriumNetwork{N}(model, args...) +end + @truncate_stacktrace MultiScaleSkipDeepEquilibriumNetwork 1 3 4 function Lux.initialstates(rng::AbstractRNG, deq::MultiScaleSkipDeepEquilibriumNetwork) @@ -231,7 +240,7 @@ function MultiScaleSkipDeepEquilibriumNetwork(main_layers::Tuple, mapping_layers split_idxs, scales) end - return MultiScaleSkipDeepEquilibriumNetwork{N}(model, shortcut, solver, sensealg, + return MultiScaleSkipDeepEquilibriumNetwork(model, shortcut, solver, sensealg, scales, split_idxs, kwargs) end diff --git a/src/solve.jl b/src/solve.jl index aea1c043..2a2825c3 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -80,8 +80,9 @@ end @truncate_stacktrace EquilibriumSolution 1 2 -function DiffEqBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver, +function SciMLBase.__solve(prob::AbstractSteadyStateProblem, alg::AbstractDEQSolver, args...; kwargs...) + # FIXME: Remove this handle sol = solve(prob, alg.alg, args...; kwargs...) u, du, retcode = sol.u, sol.resid, sol.retcode