Lux does not handle sensitivity algorithms with ReverseDiff #609

Closed
IlyaOrson opened this issue Jun 10, 2022 · 23 comments

@IlyaOrson

IlyaOrson commented Jun 10, 2022

I tried to update some sensitivity code from FastChains (#610) to Lux.
The in-place version with ReverseDiff fails, while the out-of-place version with Zygote seems to work.

using Zygote, Lux, OrdinaryDiffEq, DiffEqSensitivity, ComponentArrays, Random

function system!(du, u, p, t, controller)

    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0

    y1, y2 = u
    c1, c2 = controller(u, p)

    y1_prime = -(c1 + α * c1^2) * y1 + δ * c2
    y2_prime = (β * c1 - γ * c2) * y1

    # return [y1_prime, y2_prime]  # works
    @inbounds begin  # fails
        du[1] = y1_prime
        du[2] = y2_prime
    end
end

function loss(params, prob, tsteps)
    sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())  # fails
    # sensealg=InterpolatingAdjoint(autojacvec=ZygoteVJP())  # works
    sol = solve(prob, Tsit5(); p=params, saveat=tsteps, sensealg)  # integrate ODE system
    return -Array(sol)[2, end]  # second variable, last value, maximize
end

u0 = [1.0f0, 0.0f0]
tspan = (0.0f0, 1.0f0)
tsteps = 0.0f0:0.01f0:1.0f0

controller = Chain(
    Dense(2, 12, tanh),
    Dense(12, 12, tanh),
    Dense(12, 2),
    x -> 5 * Lux.σ.(x),  # controllers ∈ (0, 5)
)

init_params, init_states = Lux.setup(Random.default_rng(), controller)
θ = ComponentArray(init_params)
lux_controller(x, params) = controller(x, params, init_states)[1]

dudt!(du, u, p, t) = system!(du, u, p, t, lux_controller)
prob = ODEProblem(dudt!, u0, tspan, θ)

loss(params) = loss(params, prob, tsteps)

Zygote.gradient(loss, θ)

Stacktrace

ERROR: MethodError: no method matching (::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, WrappedFunction{var"#7#8"}}}})(::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
Closest candidates are:
  (::Chain)(::Any, ::Union{ComponentArray, NamedTuple}, ::NamedTuple) at C:\Users\ilyao\.julia\packages\Lux\r2DyF\src\layers\basic.jl:519
Stacktrace:
  [1] lux_controller(x::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, params::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}})
    @ Main .\REPL[51]:1
  [2] system!(du::Vector{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}}}, u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), 
NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, t::ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, controller::typeof(lux_controller))
    @ Main .\REPL[35]:6
  [3] dudt!
    @ .\REPL[52]:1 [inlined]
  [4] ODEFunction
    @ C:\Users\ilyao\.julia\packages\SciMLBase\dYFnI\src\scimlfunctions.jl:1595 [inlined]
  [5] (::DiffEqSensitivity.var"#106#121"{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(u::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, t::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\adjoint_common.jl:151
  [6] ReverseDiff.GradientTape(f::Function, input::Tuple{Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Vector{Float32}}, cfg::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}})
    @ ReverseDiff C:\Users\ilyao\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:207
  [7] ReverseDiff.GradientTape(f::Function, input::Tuple{Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Vector{Float32}})
    @ ReverseDiff C:\Users\ilyao\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:204
  [8] adjointdiffcache(g::DiffEqSensitivity.var"#df#240"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, discrete::Bool, sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, dg::Nothing, 
f::ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}; quad::Bool, noiseterm::Bool, needs_jac::Bool)
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\adjoint_common.jl:149
  [9] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(g::Function, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, discrete::Bool, sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, dg::Nothing, f::Function, checkpoints::Vector{Float32}, tols::NamedTuple{(:reltol, :abstol), 
Tuple{Float64, Float64}}, tstops::Nothing; noiseterm::Bool)
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\interpolating_adjoint.jl:72
 [10] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(g::Function, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, discrete::Bool, sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, dg::Nothing, f::Function, checkpoints::Vector{Float32}, tols::NamedTuple{(:reltol, :abstol), 
Tuple{Float64, Float64}}, tstops::Nothing)
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\interpolating_adjoint.jl:22
 [11] ODEAdjointProblem(sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, g::DiffEqSensitivity.var"#df#240"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, t::Vector{Float32}, dg::Nothing; 
checkpoints::Vector{Float32}, callback::Nothing, reltol::Float64, abstol::Float64, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\interpolating_adjoint.jl:260
 [12] _adjoint_sensitivities(sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, 
g::DiffEqSensitivity.var"#df#240"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, t::Vector{Float32}, dg::Nothing; abstol::Float64, reltol::Float64, checkpoints::Vector{Float32}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\sensitivity_interface.jl:286
 [13] adjoint_sensitivities(::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, kwargs::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, 
NamedTuple{(:callback,), Tuple{Nothing}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\sensitivity_interface.jl:271
 [14] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#239"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\GjhZ8\src\concrete_solve.jl:253
 [15] ZBack
    @ C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:205 [inlined]
 [16] (::Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#239"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}})(dy::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:231
 [17] #208
    @ C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [18] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#239"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}}}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [19] Pullback
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\R12pQ\src\solve.jl:710 [inlined]
 [20] (::typeof(∂(#solve#29)))(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [21] (::Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#29))})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207
 [22] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#29))}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [23] Pullback
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\R12pQ\src\solve.jl:703 [inlined]
 [24] (::typeof(∂(solve##kw)))(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [25] Pullback
    @ .\REPL[36]:4 [inlined]
 [26] (::typeof(∂(loss)))(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [27] Pullback
    @ .\REPL[54]:1 [inlined]
 [28] (::typeof(∂(loss)))(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [29] (::Zygote.var"#52#53"{typeof(∂(loss))})(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:41
 [30] gradient(f::Function, args::ComponentVector{Float32})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:76
 [31] top-level scope
    @ REPL[56]:1
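The repro's comments contrast two ways of writing the right-hand side: an out-of-place form that returns a new vector (`return [y1_prime, y2_prime]`) and an in-place form that writes into `du`. A minimal pure-Julia sketch of both forms follows. It is an illustration under stated assumptions: the hypothetical functions `rhs`/`rhs!` replace the neural controller with fixed, hypothetical control values `c`, and the garbled `y2_prime` line is assumed to read `(β * c1 - γ * c2) * y1`.

```julia
# Out-of-place form: allocate and return a new vector.
# `c` is a hypothetical fixed control pair standing in for the controller output.
function rhs(u, c)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = c
    return [-(c1 + α * c1^2) * y1 + δ * c2,
            (β * c1 - γ * c2) * y1]  # assumed reconstruction of y2_prime
end

# In-place form: write into a caller-allocated `du` and return nothing.
function rhs!(du, u, c)
    α, β, γ, δ = 0.5f0, 1.0f0, 1.0f0, 1.0f0
    y1, y2 = u
    c1, c2 = c
    du[1] = -(c1 + α * c1^2) * y1 + δ * c2
    du[2] = (β * c1 - γ * c2) * y1
    return nothing
end

u = [1.0f0, 0.0f0]
c = (1.0f0, 0.5f0)    # hypothetical fixed control values
du_oop = rhs(u, c)    # Float32[-1.0, 0.5]
du = similar(u)
rhs!(du, u, c)        # du now holds the same values as du_oop
```

SciML solvers typically tell the two forms apart by arity (`f(u, p, t)` vs `f!(du, u, p, t)`). In the in-place form `du` is allocated by the caller, so under AD its element type must already be able to hold tracked values; frame [2] of the trace shows exactly that, with `du::Vector{ReverseDiff.TrackedReal{...}}`.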
@avik-pal
Member

The TrackedArray that ReverseDiff constructs wraps the ComponentArray. It should be the other way around.
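Why the nesting order matters can be sketched in plain Julia with two hypothetical stand-in types (not the real ones): `Named` plays the role of a `ComponentArray` (flat storage plus named blocks reachable as `x.layer_1`), and `Track` plays the role of an AD-tracked array that wraps whatever array it is given. Judging from the Lux frames in the traces, the chain reaches each layer's parameters by name, so named access has to survive the wrapping.

```julia
# Hypothetical stand-ins, NOT the real ComponentArrays/ReverseDiff types.

# `Named` ~ ComponentArray: flat storage plus named index blocks.
struct Named{T,A<:AbstractVector{T},B<:NamedTuple} <: AbstractVector{T}
    data::A    # flat storage (may itself be a tracked wrapper)
    blocks::B  # block name => index range
    Named(data::AbstractVector{T}, blocks::NamedTuple) where {T} =
        new{T,typeof(data),typeof(blocks)}(data, blocks)
end
Base.size(n::Named) = size(getfield(n, :data))
Base.getindex(n::Named, i::Int) = getfield(n, :data)[i]
# `n.layer_1` returns a view of that block instead of a struct field.
Base.getproperty(n::Named, s::Symbol) =
    view(getfield(n, :data), getfield(getfield(n, :blocks), s))

# `Track` ~ an AD-tracked array: it wraps an array but forwards no names.
struct Track{T,A<:AbstractVector{T}} <: AbstractVector{T}
    value::A
    Track(value::AbstractVector{T}) where {T} = new{T,typeof(value)}(value)
end
Base.size(t::Track) = size(t.value)
Base.getindex(t::Track, i::Int) = t.value[i]

blocks = (layer_1 = 1:2, layer_2 = 3:3)
data = [1.0, 2.0, 3.0]

outer_track = Track(Named(data, blocks))  # tracked wrapper outside
inner_track = Named(Track(data), blocks)  # named container outside

println(collect(inner_track.layer_1))     # [1.0, 2.0]
try
    outer_track.layer_1                   # falls back to field lookup on Track
catch err
    println("failed: ", sprint(showerror, err))
end
```

With the tracked wrapper outermost, `.layer_1` falls through to plain field lookup on the wrapper and fails, which has the same shape as the `type TrackedArray has no field layer_1` error reported later in this thread with TrackerVJP. With the named container outermost, name-based access keeps working while the tracked data sits underneath.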

@ChrisRackauckas
Member

DiffEqFlux is not used in this example, transferring to the appropriate repo.

@avik-pal do you have an example of how that's done?

@ChrisRackauckas ChrisRackauckas transferred this issue from SciML/DiffEqFlux.jl Jun 11, 2022
@avik-pal
Member

Actually, this is a Lux issue. I can remove the type constraints on the parameters, and this would just work (though I am slightly anxious about allowing users to pass absolutely everything).
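The trade-off can be sketched in plain Julia. Everything below is hypothetical: `Params` stands in for `ComponentArray`, `Tracked` for an AD-tracked array, and `apply_constrained` mirrors the `(::Chain)(::Any, ::Union{ComponentArray, NamedTuple}, ::NamedTuple)` candidate listed under the `MethodError` above, while `apply_relaxed` drops the annotation.

```julia
# Hypothetical stand-ins, NOT the real Lux/ComponentArrays/ReverseDiff types.
struct Params{T} <: AbstractVector{T}  # ~ ComponentArray
    data::Vector{T}
end
Base.size(p::Params) = size(p.data)
Base.getindex(p::Params, i::Int) = p.data[i]

struct Tracked{T,A<:AbstractVector{T}} <: AbstractVector{T}  # ~ tracked array
    value::A
    Tracked(value::AbstractVector{T}) where {T} = new{T,typeof(value)}(value)
end
Base.size(t::Tracked) = size(t.value)
Base.getindex(t::Tracked, i::Int) = t.value[i]

# Constrained: only the listed parameter container types are accepted.
apply_constrained(x, ps::Union{Params,NamedTuple}) = sum(x) + sum(ps)
# Relaxed: anything supporting the operations used in the body is accepted.
apply_relaxed(x, ps) = sum(x) + sum(ps)

ps = Tracked(Params([1.0, 2.0]))  # tracked wrapper around the parameters

println(apply_relaxed([0.5], ps))  # 3.5
try
    apply_constrained([0.5], ps)   # Tracked is neither Params nor NamedTuple
catch err
    println(typeof(err))           # MethodError
end
```

Dropping the annotation lets the tracked wrapper through, but it also admits any argument at all, so a wrong parameter type surfaces later, deeper inside the layer code, instead of as a clear `MethodError` at the call site.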

@avik-pal
Member

Test the new tag and see if the problem persists.

@IlyaOrson
Author

It is working now 👍

@IlyaOrson
Author

If TrackerVJP is used, another error is thrown:

julia> Zygote.gradient(loss, θ)
ERROR: type TrackedArray has no field layer_1
Stacktrace:
  [1] getproperty
    @ .\Base.jl:38 [inlined]
  [2] macro expansion
    @ C:\Users\ilyao\.julia\packages\Lux\x5I6q\src\layers\basic.jl:0 [inlined]
  [3] applychain(layers::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, WrappedFunction{var"#11#12"}}}, x::TrackedArray{…,Vector{Float32}}, ps::TrackedArray{…,ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
    @ Lux C:\Users\ilyao\.julia\packages\Lux\x5I6q\src\layers\basic.jl:509
  [4] (::Chain{NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), Tuple{Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(NNlib.tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, WrappedFunction{var"#11#12"}}}})(x::TrackedArray{…,Vector{Float32}}, ps::TrackedArray{…,ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, st::NamedTuple{(:layer_1, :layer_2, :layer_3, :layer_4), NTuple{4, NamedTuple{(), Tuple{}}}})
    @ Lux C:\Users\ilyao\.julia\packages\Lux\x5I6q\src\layers\basic.jl:506
  [5] lux_controller(x::TrackedArray{…,Vector{Float32}}, params::TrackedArray{…,ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}})
    @ Main .\REPL[37]:1
  [6] system!(du::Vector{Tracker.TrackedReal{Float32}}, u::TrackedArray{…,Vector{Float32}}, p::TrackedArray{…,ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}, t::Float32, controller::typeof(lux_controller))
    @ Main .\REPL[15]:6
  [7] dudt!
    @ .\REPL[24]:1 [inlined]
  [8] ODEFunction
    @ C:\Users\ilyao\.julia\packages\SciMLBase\dYFnI\src\scimlfunctions.jl:1595 [inlined]
  [9] #26
    @ C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\derivative_wrappers.jl:319 [inlined]
 [10] #20
    @ C:\Users\ilyao\.julia\packages\Tracker\9xWLl\src\back.jl:148 [inlined]
 [11] forward(f::Tracker.var"#20#22"{DiffEqSensitivity.var"#26#30"{Float32, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, Tuple{TrackedArray{…,Vector{Float32}}, TrackedArray{…,ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}}}}, ps::Tracker.Params)
    @ Tracker C:\Users\ilyao\.julia\packages\Tracker\9xWLl\src\back.jl:135
 [12] forward(::Function, ::Vector{Float32}, ::ComponentVector{Float32})
    @ Tracker C:\Users\ilyao\.julia\packages\Tracker\9xWLl\src\back.jl:148
 [13] _vecjacobian!(dλ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, y::Vector{Float32}, λ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, p::ComponentVector{Float32}, t::Float32, S::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, 
Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, isautojacvec::TrackerVJP, dgrad::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, dy::Nothing, W::Nothing)
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\derivative_wrappers.jl:317
 [14] #vecjacobian!#25
    @ C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\derivative_wrappers.jl:224 [inlined]
 [15] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, 
typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}})(du::Vector{Float32}, u::Vector{Float32}, p::ComponentVector{Float32}, t::Float32)
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\interpolating_adjoint.jl:116
 [16] ODEFunction
    @ C:\Users\ilyao\.julia\packages\SciMLBase\dYFnI\src\scimlfunctions.jl:1595 [inlined]
 [17] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, true, Vector{Float32}, Nothing, Float32, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Float32, Float32, Float32, Float32, Vector{Vector{Float32}}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, 
Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), 
bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, 
DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), 
Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, 
OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, 
Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.DEOptions{Float64, Float64, Float32, Float32, PIController{Rational{Int64}}, typeof(DiffEqBase.ODE_DEFAULT_NORM), typeof(LinearAlgebra.opnorm), Nothing, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, 
Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Vector{Float32}, Vector{Float32}, Tuple{}}, Vector{Float32}, Float32, Nothing, OrdinaryDiffEq.DefaultInit}, cache::OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False})
    @ OrdinaryDiffEq C:\Users\ilyao\.julia\packages\OrdinaryDiffEq\ZgJ9s\src\perform_step\low_order_rk_perform_step.jl:627
 [18] __init(prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), 
Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, 
Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}}}, SciMLBase.StandardODEProblem}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{Val{true}}; saveat::Vector{Float32}, tstops::Vector{Float32}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, 
Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Float64, reltol::Float64, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{Int64}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ OrdinaryDiffEq C:\Users\ilyao\.julia\packages\OrdinaryDiffEq\ZgJ9s\src\solve.jl:456
 [19] __solve(::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), 
Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, 
Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}}}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Any, NTuple{7, Symbol}, NamedTuple{(:callback, :save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol), Tuple{CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, 
DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, Bool, Bool, Vector{Float32}, Vector{Float32}, Float64, Float64}}})
    @ OrdinaryDiffEq C:\Users\ilyao\.julia\packages\OrdinaryDiffEq\ZgJ9s\src\solve.jl:4
 [20] #solve_call#28
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\hZncn\src\solve.jl:428 [inlined]
 [21] solve_up(prob::ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, 
NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, Nothing, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}, Tuple{Symbol}, NamedTuple{(:callback,), Tuple{CallbackSet{Tuple{}, Tuple{DiscreteCallback{DiffEqCallbacks.var"#60#63"{Vector{Float32}}, DiffEqCallbacks.var"#61#64"{DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, 
Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, DiffEqCallbacks.var"#62#65"{typeof(SciMLBase.INITIALIZE_DEFAULT), Bool, Vector{Float32}, DiffEqSensitivity.ReverseLossCallback{Vector{Float32}, Vector{Float32}, Vector{Float32}, Base.RefValue{Int64}, LinearAlgebra.UniformScaling{Bool}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, DiffEqSensitivity.AdjointDiffCache{Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Float32}, Nothing, Base.OneTo{Int64}, UnitRange{Int64}, LinearAlgebra.UniformScaling{Bool}}}}, typeof(SciMLBase.FINALIZE_DEFAULT)}}}}}}, SciMLBase.StandardODEProblem}, sensealg::Nothing, u0::Vector{Float32}, p::ComponentVector{Float32}, args::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}; kwargs::Base.Pairs{Symbol, Any, NTuple{6, Symbol}, NamedTuple{(:save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol), Tuple{Bool, Bool, Vector{Float32}, Vector{Float32}, Float64, Float64}}})
    @ DiffEqBase C:\Users\ilyao\.julia\packages\DiffEqBase\hZncn\src\solve.jl:726
 [22] #solve#29
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\hZncn\src\solve.jl:710 [inlined]
 [23] _adjoint_sensitivities(sol::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, sensealg::InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, alg::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, g::DiffEqSensitivity.var"#df#236"{Zygote.OneElement{Float32, 2, 
Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, Colon}, t::Vector{Float32}, dg::Nothing; abstol::Float64, reltol::Float64, checkpoints::Vector{Float32}, corfunc_analytical::Nothing, callback::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\sensitivity_interface.jl:305
 [24] adjoint_sensitivities(::ODESolution{Float32, 2, Vector{Vector{Float32}}, Nothing, Nothing, Vector{Float32}, Vector{Vector{Vector{Float32}}}, ODEProblem{Vector{Float32}, Tuple{Float32, Float32}, true, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, typeof(dudt!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{Vector{Float32}}, Vector{Float32}, Vector{Vector{Vector{Float32}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float32}, Vector{Float32}, Vector{Float32}, OrdinaryDiffEq.Tsit5ConstantCache{Float32, Float32}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.DEStats}, ::Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ::Vararg{Any}; sensealg::InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, kwargs::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, 
NamedTuple{(:callback,), Tuple{Nothing}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\sensitivity_interface.jl:271
 [25] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#235"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ DiffEqSensitivity C:\Users\ilyao\.julia\packages\DiffEqSensitivity\SjURy\src\concrete_solve.jl:253
 [26] ZBack
    @ C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:205 [inlined]
 [27] (::Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#235"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}})(dy::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\chainrules.jl:231
 [28] #208
    @ C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207 [inlined]
 [29] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, Zygote.var"#kw_zpullback#37"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#235"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, TrackerVJP, Bool}, Vector{Float32}, ComponentVector{Float32, Vector{Float32}, Tuple{Axis{(layer_1 = ViewAxis(1:36, Axis(weight = ViewAxis(1:24, ShapedAxis((12, 2), NamedTuple())), bias = ViewAxis(25:36, ShapedAxis((12, 1), NamedTuple())))), layer_2 = ViewAxis(37:192, Axis(weight = ViewAxis(1:144, ShapedAxis((12, 12), NamedTuple())), bias = ViewAxis(145:156, ShapedAxis((12, 1), NamedTuple())))), layer_3 = ViewAxis(193:218, Axis(weight = ViewAxis(1:24, ShapedAxis((2, 12), NamedTuple())), bias = ViewAxis(25:26, ShapedAxis((2, 1), NamedTuple())))), layer_4 = 219:218)}}}, Tuple{}, Colon, NamedTuple{(), Tuple{}}}}}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [30] Pullback
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\hZncn\src\solve.jl:710 [inlined]
 [31] (::typeof((#solve#29)))(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [32] (::Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof((#solve#29))})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\lib\lib.jl:207
 [33] (::Zygote.var"#1750#back#210"{Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof((#solve#29))}})(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67
 [34] Pullback
    @ C:\Users\ilyao\.julia\packages\DiffEqBase\hZncn\src\solve.jl:703 [inlined]
 [35] (::typeof((solve##kw)))(Δ::Zygote.OneElement{Float32, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [36] Pullback
    @ .\REPL[49]:4 [inlined]
 [37] (::typeof((loss)))(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [38] Pullback
    @ .\REPL[41]:1 [inlined]
 [39] (::typeof((loss)))(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface2.jl:0
 [40] (::Zygote.var"#52#53"{typeof((loss))})(Δ::Float32)
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:41
 [41] gradient(f::Function, args::ComponentVector{Float32})
    @ Zygote C:\Users\ilyao\.julia\packages\Zygote\DkIUK\src\compiler\interface.jl:76
 [42] top-level scope
    @ REPL[50]:1

@IlyaOrson
Author

Probably an unrelated issue, but I noticed that the gradient returned with QuadratureAdjoint is a ComponentVector (matching the parameters), while InterpolatingAdjoint returns a Tuple.

@ba2tro
Contributor

ba2tro commented Jun 12, 2022

So, I had a similar problem running this code with Lux v0.4.4; it's resolved with the new release:

using Lux, OrdinaryDiffEq, DiffEqSensitivity, Zygote, Statistics, Random, Test

rng = Random.default_rng()

function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2]  - δ*u[2]
end

tspan = (0.0f0,3.0f0)
u0 = Float32[0.44249296,4.6280594]
p_ = Float32[1.3, 0.9, 0.8, 1.8]
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)

X = Array(solution)
t = solution.t

x̄ = mean(X, dims = 2)
noise_magnitude = Float32(5e-2)
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))

rbf(x) = exp.(-(x.^2))

U = Lux.Chain(
    Lux.Dense(2,5,rbf),
    Lux.Dense(5,5,rbf),
    Lux.Dense(5,5,rbf),
    Lux.Dense(5,2)
)

p, st = Lux.setup(rng, U)
p = Lux.ComponentArray(p)

function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p, st)[1]

    du[1] = p_true[1]*u[1] + û[1]

    du[2] = -p_true[4]*u[2] + û[2]

end

nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_)

prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)

u0 = Xₙ[:,1]
_sol = solve(prob_nn, Vern7(), u0 = u0, p=p,
                saveat = t,
                abstol=1e-8, reltol=1e-8,
                sensealg = ForwardDiffSensitivity()
                )

ū0,adj = adjoint_sensitivities(_sol,Vern7(),((out,u,p,t,i) -> out .= -1),tspan,abstol=1e-8,reltol=1e-8)

# Zygote gradients of the summed solution w.r.t. (u0, p), using two different adjoint methods
du01,dp1 = Zygote.gradient((u0,p)->sum(solve(prob_nn,Vern7(),u0=u0,p=p,abstol=1e-8,reltol=1e-8,saveat=t,sensealg=QuadratureAdjoint())),u0,p)
du02,dp2 = Zygote.gradient((u0,p)->sum(solve(prob_nn,Vern7(),u0=u0,p=p,abstol=1e-8,reltol=1e-8,saveat=t,sensealg=InterpolatingAdjoint())),u0,p)

@avik-pal
Member

Tracker needs some work (not on the Lux side, but rather for ComponentArrays compatibility). MWE:

using Lux, ComponentArrays, Random, Tracker

c = Chain(Dense(3, 2), Dense(2, 1))

x = randn(Float32, 3, 1)
ps, st = Lux.setup(Random.default_rng(), c)

ps_c = ps |> Lux.ComponentArray

Tracker.param(xs::ComponentArray) = ComponentArray(TrackedArray(float.(getdata(xs))), getaxes(xs))

Tracker.gradient(ps -> sum(first(Lux.apply(c, x, ps, st))), ps_c)
ERROR: MethodError: Cannot `convert` an object of type Vector{Float32} to an object of type SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}
Closest candidates are:
  convert(::Type{T}, ::LinearAlgebra.Factorization) where T<:AbstractArray at /mnt/softwares/julia-nightly/share/julia/stdlib/v1.8/LinearAlgebra/src/factorization.jl:58
  convert(::Type{T}, ::T) where T<:AbstractArray at abstractarray.jl:16
  convert(::Type{T}, ::T) where T at Base.jl:61
  ...
Stacktrace:
  [1] setproperty!(x::Tracker.Tracked{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, f::Symbol, v::Vector{Float32})
    @ Base ./Base.jl:39
  [2] back(x::Tracker.Tracked{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Δ::Vector{Float32}, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:53
  [3] #13
    @ /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38 [inlined]
  [4] #58
    @ ./tuple.jl:556 [inlined]
  [5] BottomRF
    @ ./reduce.jl:81 [inlined]
  [6] _foldl_impl(op::Base.BottomRF{Base.var"#58#59"{Tracker.var"#13#14"{Bool}}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{Tracker.Tracked{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Nothing}, Tuple{Vector{Float32}, Nothing}}})
    @ Base ./reduce.jl:58
  [7] foldl_impl
    @ ./reduce.jl:48 [inlined]
  [8] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
  [9] #mapfoldl#258
    @ ./reduce.jl:162 [inlined]
 [10] #foldl#259
    @ ./reduce.jl:185 [inlined]
 [11] foreach
    @ ./tuple.jl:556 [inlined]
 [12] back_(c::Tracker.Call{Tracker.var"#442#444"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, TrackedArray{,SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Tuple{UnitRange{Int64}}}, Tuple{Tracker.Tracked{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Nothing}}, Δ::Vector{Float32}, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38
 [13] back(x::Tracker.Tracked{SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}}, Δ::Vector{Float32}, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:58
--- the last 11 lines are repeated 2 more times ---
 [36] #13
    @ /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38 [inlined]
 [37] #58
    @ ./tuple.jl:556 [inlined]
 [38] BottomRF
    @ ./reduce.jl:81 [inlined]
 [39] _foldl_impl(op::Base.BottomRF{Base.var"#58#59"{Tracker.var"#13#14"{Bool}}}, init::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{Tracker.Tracked{Matrix{Float32}}, Tracker.Tracked{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}})
    @ Base ./reduce.jl:58
 [40] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [41] mapfoldl_impl(f::typeof(identity), op::Base.var"#58#59"{Tracker.var"#13#14"{Bool}}, nt::Nothing, itr::Base.Iterators.Zip{Tuple{Tuple{Tracker.Tracked{Matrix{Float32}}, Tracker.Tracked{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}})
    @ Base ./reduce.jl:44
 [42] mapfoldl(f::Function, op::Function, itr::Base.Iterators.Zip{Tuple{Tuple{Tracker.Tracked{Matrix{Float32}}, Tracker.Tracked{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}; init::Nothing)
    @ Base ./reduce.jl:162
 [43] #foldl#259
    @ ./reduce.jl:185 [inlined]
 [44] foreach(::Function, ::Tuple{Tracker.Tracked{Matrix{Float32}}, Tracker.Tracked{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}, ::Tuple{Matrix{Float32}, Matrix{Float32}})
    @ Base ./tuple.jl:556
 [45] back_(c::Tracker.Call{Tracker.var"#back#622"{2, typeof(+), Tuple{TrackedArray{,Matrix{Float32}}, TrackedArray{,Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}}, Tuple{Tracker.Tracked{Matrix{Float32}}, Tracker.Tracked{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}}}}, Δ::Matrix{Float32}, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38
 [46] back(x::Tracker.Tracked{Matrix{Float32}}, Δ::Matrix{Float32}, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:58
 [47] #13
    @ /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38 [inlined]
 [48] #58
    @ ./tuple.jl:556 [inlined]
 [49] BottomRF
    @ ./reduce.jl:81 [inlined]
 [50] _foldl_impl
    @ ./reduce.jl:58 [inlined]
 [51] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [52] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [53] #mapfoldl#258
    @ ./reduce.jl:162 [inlined]
 [54] #foldl#259
    @ ./reduce.jl:185 [inlined]
 [55] foreach
    @ ./tuple.jl:556 [inlined]
 [56] back_(c::Tracker.Call{Tracker.var"#552#553"{TrackedArray{,Matrix{Float32}}}, Tuple{Tracker.Tracked{Matrix{Float32}}}}, Δ::Float32, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:38
 [57] back(x::Tracker.Tracked{Float32}, Δ::Int64, once::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:58
 [58] #back!#15
    @ /mnt/julia/packages/Tracker/9xWLl/src/back.jl:77 [inlined]
 [59] #back!#32
    @ /mnt/julia/packages/Tracker/9xWLl/src/lib/real.jl:16 [inlined]
 [60] back!(x::Tracker.TrackedReal{Float32})
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/lib/real.jl:14
 [61] gradient_(f::Function, xs::ComponentVector{Float32})
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:4
 [62] gradient(f::Function, xs::ComponentVector{Float32}; nest::Bool)
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:164
 [63] gradient(f::Function, xs::ComponentVector{Float32})
    @ Tracker /mnt/julia/packages/Tracker/9xWLl/src/back.jl:164
 [64] top-level scope
    @ REPL[33]:1
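The top frame shows `setproperty!` trying to store a plain `Vector{Float32}` gradient in a `Tracker.Tracked` node whose type parameter is a `SubArray`, i.e. the view that a ComponentArray block produces. Assuming the gradient field is typed like the value (which the error message suggests), the failure can be reproduced with Base Julia alone. `TrackedSketch` is a hypothetical stand-in, not Tracker's real type:

```julia
# Hypothetical stand-in for a tracked node whose gradient field has the same
# concrete type as its value (an illustration, not Tracker's actual code).
mutable struct TrackedSketch{T}
    data::T
    grad::T
end

buf = Float32[1, 2, 3, 4]
v = view(buf, 1:2)                          # a SubArray, like a ComponentArray block
node = TrackedSketch(v, view(zeros(Float32, 2), 1:2))

Δ = Float32[0.5, 0.5]                       # incoming plain-Vector gradient

# Field assignment goes through setproperty!, which calls
# convert(SubArray{...}, Δ); no such conversion exists, so this throws.
err = try
    node.grad = Δ
    nothing
catch e
    e
end
println(err isa MethodError)

# Accumulating in place into the existing view never calls convert.
node.grad .+= Δ
println(node.grad)
```

Writing into the existing view with broadcasting, rather than rebinding the field, is the general pattern that avoids this `convert`; whether that is the right change inside Tracker is not established here.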

@ba2tro
Contributor

ba2tro commented Jun 13, 2022

Yeah, NeuralSDE and NeuralDSDE use TrackerAdjoint by default and give a similar error with the Lux-compatible constructors. The following is the error I got with an AugmentedNDE layer constructed with NeuralSDE:

julia> grads = Zygote.gradient((x,p,st1,st2) -> sum(andsde(x,p,st1,st2)[1]),x,p,st1,st2)

ERROR: MethodError: no method matching zero(::Type{ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:250, Axis(weight = ViewAxis(1:200, ShapedAxis((50, 4), NamedTuple())), bias = ViewAxis(201:250, ShapedAxis((50, 1), NamedTuple())))), layer_2 = ViewAxis(251:454, Axis(weight = ViewAxis(1:200, ShapedAxis((4, 50), NamedTuple())), bias = ViewAxis(201:204, ShapedAxis((4, 1), NamedTuple())))))}}}})
Closest candidates are:
  zero(::Union{Type{P}, P}) where P<:Dates.Period at C:\Users\user\AppData\Local\Programs\Julia-1.7.2\share\julia\stdlib\v1.7\Dates\src\periods.jl:53
  zero(::Union{AbstractAlgebra.Generic.LaurentSeriesFieldElem{T}, AbstractAlgebra.Generic.LaurentSeriesRingElem{T}} where T<:AbstractAlgebra.RingElement) at C:\Users\user\.julia\packages\AbstractAlgebra\nmiq9\src\generic\LaurentSeries.jl:466
  zero(::Union{AbstractAlgebra.Generic.LaurentSeriesFieldElem{T}, AbstractAlgebra.Generic.LaurentSeriesRingElem{T}} where T<:AbstractAlgebra.RingElement, ::String; cached) at C:\Users\user\.julia\packages\AbstractAlgebra\nmiq9\src\generic\LaurentSeries.jl:480
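This `MethodError` is raised for `zero` called on the ComponentVector *type*, not on a value; the excerpt does not show which call site does this. A Base-only illustration of the same distinction, with a plain `Vector{Float32}` standing in for the ComponentVector (an assumption made to keep the snippet dependency-free):

```julia
# Base.zero is defined for array *values* but not for array *types*:
# for a plain Vector the type carries no length, so there is nothing to build.
x = Float32[1, 2, 3]
zx = zero(x)                     # zero-filled array with the same shape and eltype

err = try
    zero(typeof(x))              # zero(::Type{Vector{Float32}})
    nothing
catch e
    e
end
println(err isa MethodError)
```

If the calling code has a parameter instance at hand, `zero(p)` rather than `zero(typeof(p))` sidesteps the missing method; whether that is possible inside TrackerAdjoint is not established in this thread.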

@ba2tro
Contributor

ba2tro commented Jun 22, 2022

Same issue with ReverseDiffAdjoint: it wraps the parameters during the reverse pass, and the wrapped type doesn't satisfy the type constraint on the parameter argument of a Lux.Chain.
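In the error reported further down, `Chain` is only callable with parameters of type `Union{ComponentArray, NamedTuple}`, and a `ReverseDiff.TrackedArray` wrapping a ComponentVector is neither. The dispatch side of that mismatch can be sketched in plain Julia. `ParamVec` and `TrackedVec` are hypothetical stand-ins for ComponentVector and `ReverseDiff.TrackedArray`, and widening the signature to `AbstractVector` is one assumed way to lift the restriction, not necessarily how Lux resolved it:

```julia
# Hypothetical stand-ins (not the real ComponentArrays/ReverseDiff types).
struct ParamVec <: AbstractVector{Float64}          # plays ComponentVector
    w::Vector{Float64}
end
Base.size(p::ParamVec) = size(p.w)
Base.getindex(p::ParamVec, i::Int) = p.w[i]

struct TrackedVec{T<:AbstractVector{Float64}} <: AbstractVector{Float64}  # plays TrackedArray
    value::T
end
Base.size(t::TrackedVec) = size(t.value)
Base.getindex(t::TrackedVec, i::Int) = t.value[i]

# Restrictive signature: only the concrete parameter container is accepted,
# so a wrapper around it falls through to a MethodError.
narrow_layer(x, ps::ParamVec) = x .* ps

# Widened signature: any AbstractVector of parameters, wrapped or not.
wide_layer(x, ps::AbstractVector) = x .* ps

x = [1.0, 2.0]
ps = ParamVec([3.0, 4.0])
tps = TrackedVec(ps)

y = narrow_layer(x, ps)             # matches ParamVec

err = try
    narrow_layer(x, tps)            # TrackedVec{ParamVec} is not a ParamVec
    nothing
catch e
    e
end

y_wide = wide_layer(x, tps)         # dispatch succeeds on the wrapper
println(err isa MethodError, " ", y == y_wide)
```

In the real setting the tracked array has to flow through the layer untouched so that ReverseDiff can record the operations, which is why unwrapping it is not an option; the sketch only covers method dispatch.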

@avik-pal
Member

What is the exact error stacktrace? I do test Chain with ReverseDiff: https://github.com/avik-pal/Lux.jl/blob/ecc5dc5d86c603a429a4372bc2f13b360fb8a60c/test/autodiff.jl#L10-L24

@ba2tro
Contributor

ba2tro commented Jun 22, 2022

The issue here, as I said above, is that ReverseDiffAdjoint wraps the second argument (the parameters):

ERROR: MethodError: no method matching (::Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}})(::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true},
Tuple{ComponentArrays.Axis{(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))}}}}, ::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}})
Closest candidates are:
  (::Lux.Chain)(::Any, !Matched::Union{ComponentArrays.ComponentArray, NamedTuple}, ::NamedTuple) at C:\Users\user\.julia\packages\Lux\qQlb5\src\layers\basic.jl:519

Stacktrace:

(::DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity),
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}})(u::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, t::Float32; st::NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), 
Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}) at C:\Users\user\.julia\dev\DiffEqFlux\src\neural_de.jl

(::DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}})(u::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, p::ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = 
ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, t::Float32) at C:\Users\user\.julia\dev\DiffEqFlux\src\neural_de.jl

(::SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing})(::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ::Vararg{Any}) at C:\Users\user\.julia\packages\SciMLBase\dYFnI\src\scimlfunctions.jl

(::DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, 
:layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}})(::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ::Vararg{Any}) at C:\Users\user\.julia\packages\DiffEqSensitivity\kMyur\src\concrete_solve.jl

(::SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, 
DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, 
Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, 
:layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing})(::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ::Vararg{Any}) at C:\Users\user\.julia\packages\SciMLBase\dYFnI\src\scimlfunctions.jl

sde_determine_initdt(u0::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, t::Float32, tdir::Float32, dtmax::Float32, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = 
ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, 
DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), 
Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, order::Rational{Int64}, integrator::StochasticDiffEq.SDEIntegrator{SOSRI, false, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = 
ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Float32, Float32, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, RODESolution{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing, Vector{Float32}, 
DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 
2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, 
Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, 
:layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, 
DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), 
NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, 
typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, SOSRI, StochasticDiffEq.LinearInterpolationData{Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Vector{Float32}}, DiffEqBase.DEStats}, StochasticDiffEq.FourStageSRIConstantCache{ReverseDiff.TrackedReal{Float32, 
Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32}, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), 
Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, 
Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, 
Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, 
LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Nothing, StochasticDiffEq.SDEOptions{Float32, Float32, OrdinaryDiffEq.PIController{Float32}, typeof(DiffEqBase.ODE_DEFAULT_NORM), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Float64, Float64, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Tuple{}, StepRangeLen{Float32, 
Float64, Float64, Int64}, Tuple{}}, Nothing, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing}) at C:\Users\user\.julia\packages\StochasticDiffEq\LYyNp\src\initdt.jl

`auto_dt_reset!(integrator::StochasticDiffEq.SDEIntegrator{SOSRI, false, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Float32, Float32, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, 
Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, RODESolution{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing, Vector{Float32}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, 
Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), 
NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, 
Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = 
ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), 
Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, 
Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, 
Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, SOSRI, StochasticDiffEq.LinearInterpolationData{Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Vector{Float32}}, DiffEqBase.DEStats}, StochasticDiffEq.FourStageSRIConstantCache{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32}, SDEFunction{false, DiffEqSensitivity.var"#f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), 
Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, 
NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, 
ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol,…

@ba2tro
Contributor

ba2tro commented Jun 22, 2022

This is related to SciML/DiffEqFlux.jl#736, where I am trying to get the doc example working with the Lux-compatible constructor for NeuralDSDE defined in the same PR. The error occurs when sensealg = ReverseDiffAdjoint() is set here https://github.com/Abhishek-1Bhatt/DiffEqFlux.jl/blob/327611f928a166d8e466675ab3365dbe0f422cb5/src/neural_de.jl#L212 (it is currently set to InterpolatingAdjoint()). It is triggered when running this line of the tutorial, result1 = Optimization.solve(optprob, opt, callback = callback, maxiters = 100): https://github.com/Abhishek-1Bhatt/DiffEqFlux.jl/blob/neural_sde/docs/src/examples/neural_sde.md#:~:text=result1%20%3D%20Optimization.solve(optprob%2C%20opt%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20callback%20%3D%20callback%2C%20maxiters%20%3D%20100)

@avik-pal
Member

That is an old version of Lux

@ba2tro
Contributor

ba2tro commented Jun 22, 2022

It is really strange that it uses Lux v0.4.3 while I have v0.4.6 installed. Do you have any idea what could cause that?

(@v1.7) pkg> st
      Status `C:\Users\user\.julia\environments\v1.7\Project.toml`
  [4fba245c] ArrayInterface v6.0.17 
  [6e4b80f9] BenchmarkTools v1.3.1  
  [479239e8] Catalyst v12.0.0       
  [b0b7db55] ComponentArrays v0.12.0
  [2445eb08] DataDrivenDiffEq v0.8.4
  [bcd4f6db] DelayDiffEq v5.37.0    
  [f3b72e0c] DiffEqDevTools v2.30.0 
  [aae7a2af] DiffEqFlux v1.50.0
  [9fdde737] DiffEqOperators v4.43.1
  [41bf760c] DiffEqSensitivity v6.80.0 `C:\Users\user\.julia\dev\DiffEqSensitivity`
  [0c46a032] DifferentialEquations v7.1.0
  [b4f34e82] Distances v0.10.7
  [31c24e10] Distributions v0.25.62
  [ced4e74d] DistributionsAD v0.6.41
  [7da242da] Enzyme v0.10.1
  [5789e2e9] FileIO v1.14.0
  [587475ba] Flux v0.13.3
  [f67ccb44] HDF5 v0.16.10
  [033835bb] JLD2 v0.4.22
  [98e50ef6] JuliaFormatter v1.0.3
  [e5e0dc1b] Juno v0.8.4
  [7f56f5a3] LSODA v0.7.0
  [bdcacae8] LoopVectorization v0.12.118
  [b2108857] Lux v0.4.6
  [23992714] MAT v0.10.3
  [eb30cadb] MLDatasets v0.7.2
  [ee78f7c6] Makie v0.17.7
  [961ee093] ModelingToolkit v8.14.1
  [54ca160b] ODEInterface v0.5.0
  [09606e27] ODEInterfaceDiffEq v3.10.1
  [5913d0e6] OperatorLearning v0.2.2 `C:\Users\user\.julia\dev\OperatorLearning`
  [429524aa] Optim v1.7.0
  [7f7a1694] Optimization v3.7.0
  [253f991c] OptimizationFlux v0.1.0
  [4e6fcdb7] OptimizationNLopt v0.1.0
  [36348300] OptimizationOptimJL v0.1.1
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [500b13db] OptimizationPolyalgorithms v0.1.0
  [1dea7af3] OrdinaryDiffEq v6.16.1
  [91a5bcdd] Plots v1.30.1
  [b4db0fb7] ReactionNetworkImporters v0.13.4
  [0bca4576] SciMLBase v1.41.3
  [de6bee2f] SimpleChains v0.2.12 `C:\Users\user\.julia\dev\SimpleChains`
  [789caeaf] StochasticDiffEq v6.49.1
  [c3572dad] Sundials v4.9.4
  [a759f4b9] TimerOutputs v0.5.20
  [3d5dd08c] VectorizationBase v0.21.36
  [e88e6eb3] Zygote v0.6.40

Anyway, since I get this error locally, I haven't tried this in the doctests yet. I'll give it a try; hopefully it works on CI with ReverseDiffAdjoint.

@ChrisRackauckas
Member

Remove other packages, like OperatorLearning, from the environment.

@ba2tro
Contributor

ba2tro commented Jun 23, 2022

With Lux v0.4.6, this error shows up:

ERROR: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 20 and 2")
Stacktrace:

_bcs1 at .\broadcast.jl

_bcs at .\broadcast.jl

broadcast_shape at .\broadcast.jl

combine_axes at .\broadcast.jl

instantiate at .\broadcast.jl

materialize(bc::Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(/), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(max), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(DiffEqBase.ODE_DEFAULT_NORM), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(+), Tuple{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32}}, Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(DiffEqBase.ODE_DEFAULT_NORM), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(-), Tuple{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}}, Float32}}}}, Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(+), Tuple{Float64, Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(*), Tuple{Base.Broadcast.Broadcasted{ReverseDiff.TrackedStyle, Nothing, typeof(DiffEqBase.ODE_DEFAULT_NORM), Tuple{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Float32}}, Float64}}}}}}) at .\broadcast.jl

sde_determine_initdt(u0::ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, t::Float32, tdir::Float32, dtmax::Float32, abstol::Float64, reltol::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), prob::SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = 
ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, 
DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), 
Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, order::Rational{Int64}, integrator::StochasticDiffEq.SDEIntegrator{SOSRI, false, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = 
ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Float32, Float32, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, RODESolution{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing, Vector{Float32}, 
DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 
2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, 
Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, 
:layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, 
DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), 
NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, 
typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, SOSRI, StochasticDiffEq.LinearInterpolationData{Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Vector{Float32}}, DiffEqBase.DEStats}, StochasticDiffEq.FourStageSRIConstantCache{ReverseDiff.TrackedReal{Float32, 
Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32}, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), 
Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, 
Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, 
Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, 
LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Nothing, StochasticDiffEq.SDEOptions{Float32, Float32, OrdinaryDiffEq.PIController{Float32}, typeof(DiffEqBase.ODE_DEFAULT_NORM), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Float64, Float64, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Tuple{}, StepRangeLen{Float32, 
Float64, Float64, Int64}, Tuple{}}, Nothing, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing}) at C:\Users\user.julia\packages\StochasticDiffEq\LYyNp\src\initdt.jl

auto_dt_reset!(integrator::StochasticDiffEq.SDEIntegrator{SOSRI, false, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Float32, Float32, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, 
Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, RODESolution{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing, Vector{Float32}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, 
Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), 
NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, 
Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = 
ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), 
Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, 
Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, 
Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}, SOSRI, StochasticDiffEq.LinearInterpolationData{Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Vector{Float32}}, DiffEqBase.DEStats}, StochasticDiffEq.FourStageSRIConstantCache{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32}, SDEFunction{false, DiffEqSensitivity.var"#_f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), 
Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, 
NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, 
ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), 
typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#_g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt_#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, 
DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, 
Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, Nothing, StochasticDiffEq.SDEOptions{Float32, Float32, OrdinaryDiffEq.PIController{Float32}, typeof(DiffEqBase.ODE_DEFAULT_NORM), Nothing, CallbackSet{Tuple{}, Tuple{}}, typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, DataStructures.BinaryHeap{Float32, DataStructures.FasterForward}, Nothing, Nothing, Int64, Float64, Float64, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Tuple{}, StepRangeLen{Float32, Float64, Float64, Int64}, Tuple{}}, Nothing, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing}) at C:\Users\user.julia\packages\StochasticDiffEq\LYyNp\src\integrators\integrator_interface.jl

`handle_dt!(integrator::StochasticDiffEq.SDEIntegrator{SOSRI, false, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Float32, Float32, ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, 
Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, Nothing, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, RODESolution{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, Nothing, Nothing, Vector{Float32}, DiffEqNoiseProcess.NoiseProcess{ReverseDiff.TrackedReal{Float32, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, 3, Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Vector{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, typeof(DiffEqNoiseProcess.WHITE_NOISE_DIST), typeof(DiffEqNoiseProcess.WHITE_NOISE_BRIDGE), false, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}}, false}, ResettableStacks.ResettableStack{Tuple{Float32, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, 
Matrix{Float32}}}, false}, DiffEqNoiseProcess.RSWM{Float64}, Nothing, RandomNumbers.Xorshifts.Xoroshiro128Plus}, SDEProblem{ReverseDiff.TrackedArray{Float32, Float32, 2, Matrix{Float32}, Matrix{Float32}}, Tuple{Float32, Float32}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}}, Nothing, SDEFunction{false, DiffEqSensitivity.var"#f#281"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), 
NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, 
Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, DiffEqSensitivity.var"#g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = 
ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), 
typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, LinearAlgebra.UniformScaling{Bool}, Nothing, typeof(DiffEqFlux.basic_tgrad), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqFlux.var"#g#152"{DiffEqFlux.var"#g#149#153"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Tuple{}}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Optimisers.Restructure{Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Tuple{}}, Tuple{Float32, Float32}, Tuple{SOSRI}, Base.Pairs{Symbol, Any, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:saveat, :reltol, :abstol), 
Tuple{StepRangeLen{Float32, Float64, Float64, Int64}, Float64, Float64}}}}}, NamedTuple{(), Tuple{}}}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, DiffEqSensitivity.var"#g#282"{SDEProblem{Matrix{Float32}, Tuple{Float32, Float32}, false, ComponentArrays.ComponentVector{Float32, Vector{Float32}, Tuple{ComponentArrays.Axis{(p1 = ViewAxis(1:252, Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:150, Axis(weight = ViewAxis(1:100, ShapedAxis((50, 2), NamedTuple())), bias = ViewAxis(101:150, ShapedAxis((50, 1), NamedTuple())))), layer_3 = ViewAxis(151:252, Axis(weight = ViewAxis(1:100, ShapedAxis((2, 50), NamedTuple())), bias = ViewAxis(101:102, ShapedAxis((2, 1), NamedTuple())))))), p2 = ViewAxis(253:258, Axis(weight = ViewAxis(1:4, ShapedAxis((2, 2), NamedTuple())), bias = ViewAxis(5:6, ShapedAxis((2, 1), NamedTuple())))))}}}, Nothing, SDEFunction{false, DiffEqFlux.var"#dudt#150"{DiffEqFlux.var"#dudt_#148#151"{NeuralDSDE{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}, Lux.Dense{true, typeof(tanh_fast), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}, Lux.Dense{true, typeof(identity), typeof(Lux.glorot_uniform), typeof(Lux.zeros32)}}}}, Vector{Bool}, Optimisers.Restructure{Lux.Chain{NamedTuple{(:layer_1, :layer_2, :layer_3), Tuple{ActivationFunction{var"#5#6"}…

@ba2tro
Contributor

ba2tro commented Jun 23, 2022

@ChrisRackauckas
Member

Yes, it is a wrapper-package issue: mixing ComponentArrays and TrackedArray, in both senses. I'm going to punt on this for now, but it's something to keep in mind.

@avik-pal
Member

avik-pal commented Jun 24, 2022

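So fixing Tracker.jl is not hard:

using Tracker, ComponentArrays  # getdata and getaxes come from ComponentArrays

# Track the flat data vector, then rewrap it with the original axes
Tracker.param(c::ComponentArray) = ComponentArray(Tracker.param(getdata(c)), getaxes(c))
# Gradient and tracker lookups are forwarded to the underlying flat data
Tracker.grad(c::ComponentArray) = Tracker.grad(getdata(c))
Tracker.tracker(c::ComponentArray) = Tracker.tracker(getdata(c))

though I am not sure where to open the PR: Tracker or ComponentArrays?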

EDIT: I was wrong; this doesn't work when the nesting goes beyond two levels (layer_1 = (weight = ..., bias = ...), ...).

@ChrisRackauckas
Member

ComponentArrays

@avik-pal
Member

The original code works with the latest Lux release (v0.4.53).


4 participants