neural_ode_sciml example fails when Dense layer replaced by GRU #432

Closed
John-Boik opened this issue Oct 12, 2020 · 20 comments
Labels
good first issue Good for newcomers

Comments

@John-Boik

As a first step toward GRU-ODE or ODE-LSTM implementations, I'd like to swap out the Dense layer in the neural_ode_sciml example for a GRU layer. However, doing so raises LoadError: DimensionMismatch("array could not be broadcast to match destination"). I don't understand exactly where the problem is occurring, or how to fix it. Any ideas?
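For background on why a GRU behaves differently from a Dense layer here: in Flux, a GRU is a stateful layer (a Flux.Recur wrapping a GRUCell), so it is not a pure function of its input the way an ODE right-hand side needs to be. A minimal sketch of that statefulness (illustrative only, assuming the Flux version used in this issue):

```julia
using Flux

gru = Flux.GRU(50, 2)   # Recur-wrapped GRUCell; hidden state is carried across calls
x = rand(Float32, 50)

y1 = gru(x)
y2 = gru(x)             # generally differs from y1: the hidden state was updated in between

Flux.reset!(gru)        # restore the initial hidden state
y3 = gru(x)             # matches y1 again
```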

The main differences from the original example are:

  • using statements have been changed to import statements (for clarity)
  • FastChain has been changed to Chain
  • (x, p) -> x.^3 has been changed to x -> x.^3
  • the code has been placed in a module, called via a Linux command window with include("./TestDiffEq3b.jl")
  • the Dense layer has been changed to a GRU layer
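
To illustrate the FastChain-to-Chain change in the bullets above (a sketch, not code from the original example): FastChain layers are called as layer(x, p) with the parameters passed in explicitly, whereas Flux.Chain layers take only the input, which is why the anonymous cube layer drops its p argument:

```julia
using DiffEqFlux, Flux

# FastChain: each layer is f(x, p); parameters are threaded through explicitly
fast_dudt = DiffEqFlux.FastChain((x, p) -> x.^3, DiffEqFlux.FastDense(2, 50, tanh))

# Flux.Chain: each layer is f(x); parameters live inside the layer objects
flux_dudt = Flux.Chain(x -> x.^3, Flux.Dense(2, 50, tanh))
```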

This issue is loosely related to Training of UDEs with recurrent networks #391, Flux.destructure doesn't preserve RNN state #1329, and ODE-LSTM layer #422.

The code is as follows, with the Dense layer commented out and replaced by the GRU layer:

module TestDiffeq3b

using Revise
using Infiltrator
using Formatting

import DiffEqFlux
import OrdinaryDiffEq 
import Flux
import Optim
import Plots

u0 = Float32[2.0; 0.0]    
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)    
    true_A = [-0.1 2.0; -2.0 -0.1]    
    du .= ((u.^3)'true_A)'    
end

prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))    

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 50, tanh),
    #Flux.Dense(50, 2)
    Flux.GRU(50, 2)
    )

p, re = Flux.destructure(dudt2)  
neural_ode_f(u, p, t) = re(p)(u)

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)


function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps))
    return res
end


function loss_neuralode(p)
        pred = predict_neuralode(p)  # (2,30)
        loss = sum(abs2, ode_data .- pred)  # scalar
        return loss, pred
end


callback = function (p, l, pred; doplot = true)
    display(l)
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data")
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode, 
    p,
    Flux.ADAM(0.05), 
    cb = callback,
    maxiters = 300
    )


result_neuralode2 = DiffEqFlux.sciml_train(
    loss_neuralode,
    result_neuralode.minimizer,
    Optim.LBFGS(),
    cb = callback,
    allow_f_increases = false
    )


end    # ------------------------------- module -----------------------------------
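
One way to probe where the shapes diverge, before the solver and adjoint ever run (a hypothetical diagnostic, assuming the same package versions as above):

```julia
using Flux

dudt2 = Flux.Chain(x -> x.^3, Flux.Dense(2, 50, tanh), Flux.GRU(50, 2))
p, re = Flux.destructure(dudt2)
u0 = Float32[2.0, 0.0]

@show length(p)         # does this include the GRU's hidden-state arrays as "parameters"?
@show size(re(p)(u0))   # an ODE RHS must return something shaped like u0, i.e. (2,)
```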

The error message is:


ERROR: LoadError: DimensionMismatch("array could not be broadcast to match destination")
Stacktrace:
 [1] check_broadcast_shape at ./broadcast.jl:520 [inlined]
 [2] check_broadcast_axes at ./broadcast.jl:523 [inlined]
 [3] instantiate at ./broadcast.jl:269 [inlined]
 [4] materialize! at ./broadcast.jl:848 [inlined]
 [5] materialize!(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(identity),Tuple{Array{Float32,1}}}) at ./broadcast.jl:845
 [6] _vecjacobian!(::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Array{Float32,1}, ::Float32, ::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Not
hing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}}, ::DiffEqSensitivity.ZygoteVJP, ::SubArray{Float32,1,Array{Float32,1},Tuple{UnitRange{Int64}},true}, ::Nothing) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:296
 [7] _vecjacobian! at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:193 [inlined]
 [8] #vecjacobian!#20 at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/derivative_wrappers.jl:147 [inlined]
 [9] (::DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}})(::Array{Float32,1}, ::Array{Float32,1}, ::Array{Float32,1}, ::Float32) at 
/home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/interpolating_adjoint.jl:145
 [10] ODEFunction at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/diffeqfunction.jl:248 [inlined]
 [11] initialize!(::OrdinaryDiffEq.ODEIntegrator{OrdinaryDiffEq.Tsit5,true,Array{Float32,1},Nothing,Float32,Array{Float32,1},Float32,Float32,Float32,Array{Array{Float32,1},1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pai
rs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, 
Float32}}}}}}}}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},
LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}},DiffEqBase.DEStats},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{}
,NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},OrdinaryDiffEq.DEOptions{Float64,Float64,Float32,Float32,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, 
Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float32,Base.Order.ForwardOrdering},DataStructures.BinaryHeap{Float32,Base.Order.ForwardOrdering},Nothing,Nothing,Int64,Array{Float32,1},Array{Float32,1},Tuple{}},Array{Float32,1},Float32,Nothing,OrdinaryDiffEq.DefaultInit}, ::OrdinaryDiffEq.Tsit5Cache{Array{Float32,1},Array{Float32,1},Array{Float32,1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}}) at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/perform_step/low_order_rk_perform_step.jl:623
 [12] __init(::DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},Line
arAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Type{Val{true}}; saveat::Array{Float32,1}, tstops::Array{Float32,1}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}, dense::Bool, calck::Bool, dt::Float32, dtmin::Nothing, dtmax::Float32, force_dtmin::Bool, adaptive::Bool, gamma::Rational{Int64}, abstol::Float64, reltol::Float64, qmin::Rational{Int64}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta1::Nothing, beta2::Nothing, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, 
progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool, alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEq.DefaultInit, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/solve.jl:428
 [13] #__solve#391 at /home/jboik/.julia/packages/OrdinaryDiffEq/HO8vN/src/solve.jl:4 [inlined]
 [14] solve_call(::DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},DiffEqBase.ODEFunction{true,DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction{DiffEqSensitivity.AdjointDiffCache{Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool}},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats},Nothing,DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}},
LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Symbol,DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiffEqBase.CallbackSet{Tuple{},Tuple{DiffEqBase.DiscreteCallback{DiffEqCallbacks.var"#33#38"{Base.RefValue{Union{Nothing, 
Float32}}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}},DiffEqCallbacks.var"#35#40"{Bool,DiffEqCallbacks.var"#37#42"{Bool},DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},Base.RefValue{Union{Nothing, Float32}},DiffEqCallbacks.var"#34#39"{DiffEqSensitivity.var"#94#96"{Base.RefValue{Int64},Array{Float32,1}},DiffEqSensitivity.var"#95#97"{DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon},DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Base.OneTo{Int64},UnitRange{Int64},LinearAlgebra.UniformScaling{Bool},Bool,Nothing,Nothing,Nothing,Nothing,Bool,Array{Float32,1},Array{Float32,1},Array{Float32,1},Base.RefValue{Int64},Int64,LinearAlgebra.UniformScaling{Bool}},Base.RefValue{Union{Nothing, Float32}}}}}}}}}},DiffEqBase.StandardODEProblem}, ::OrdinaryDiffEq.Tsit5; merge_callbacks::Bool, kwargs::Base.Iterators.Pairs{Symbol,Any,NTuple{6,Symbol},NamedTuple{(:save_everystep, :save_start, :saveat, :tstops, :abstol, :reltol),Tuple{Bool,Bool,Array{Float32,1},Array{Float32,1},Float64,Float64}}}) at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:65
 [15] #solve_up#458 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:86 [inlined]
 [16] #solve#457 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:74 [inlined]
 [17] _adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, ::OrdinaryDiffEq.Tsit5, ::DiffEqSensitivity.var"#df#134"{Array{Float32,2},Array{Float32,1},Colon}, ::Array{Float32,1}, ::Nothing; abstol::Float64, reltol::Float64, checkpoints::Array{Float32,1}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:22
 [18] _adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, ::OrdinaryDiffEq.Tsit5, ::Function, ::Array{Float32,1}, ::Nothing) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:13 (repeats 2 times)
 [19] adjoint_sensitivities(::DiffEqBase.ODESolution{Float32,2,Array{Array{Float32,1},1},Nothing,Nothing,Array{Float32,1},Array{Array{Array{Float32,1},1},1},DiffEqBase.ODEProblem{Array{Float32,1},Tuple{Float32,Float32},false,Array{Float32,1},DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},DiffEqBase.StandardODEProblem},OrdinaryDiffEq.Tsit5,OrdinaryDiffEq.InterpolationData{DiffEqBase.ODEFunction{false,typeof(Main.TestDiffeq3b.neural_ode_f),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Array{Float32,1},1},Array{Float32,1},Array{Array{Array{Float32,1},1},1},OrdinaryDiffEq.Tsit5ConstantCache{Float32,Float32}},DiffEqBase.DEStats}, ::OrdinaryDiffEq.Tsit5, ::Vararg{Any,N} where N; sensealg::DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool}, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/sensitivity_interface.jl:6
 [20] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{OrdinaryDiffEq.Tsit5,DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},Array{Float32,1},Tuple{},Colon})(::Array{Float32,2}) at /home/jboik/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:144
 [21] #673#back at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [22] #145 at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [23] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{DiffEqBase.var"#673#back#471"{DiffEqSensitivity.var"#adjoint_sensitivity_backpass#133"{OrdinaryDiffEq.Tsit5,DiffEqSensitivity.InterpolatingAdjoint{0,true,Val{:central},Bool,Bool},Array{Float32,1},Array{Float32,1},Tuple{},Colon}},Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [24] #solve#457 at /home/jboik/.julia/packages/DiffEqBase/gLFRA/src/solve.jl:74 [inlined]
 [25] (::typeof(∂(#solve#457)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [26] (::Zygote.var"#145#146"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175
 [27] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{typeof(∂(#solve#457)),Tuple{NTuple{6,Nothing},Tuple{Nothing}}}})(::Array{Float32,2}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [28] (::typeof(∂(solve##kw)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [29] predict_neuralode at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:70 [inlined]
 [30] (::typeof(∂(predict_neuralode)))(::Array{Float32,2}) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [31] loss_neuralode at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:76 [inlined]
 [32] #145 at /home/jboik/.julia/packages/Zygote/chgvX/src/lib/lib.jl:175 [inlined]
 [33] (::Zygote.var"#1681#back#147"{Zygote.var"#145#146"{typeof(∂(loss_neuralode)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float32,Nothing}) at /home/jboik/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [34] #74 at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:120 [inlined]
 [35] (::typeof(∂(λ)))(::Float32) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface2.jl:0
 [36] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float32) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:177
 [37] gradient(::Function, ::Zygote.Params) at /home/jboik/.julia/packages/Zygote/chgvX/src/compiler/interface.jl:54
 [38] macro expansion at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:119 [inlined]
 [39] macro expansion at /home/jboik/.julia/packages/ProgressLogging/BBN0b/src/ProgressLogging.jl:328 [inlined]
 [40] (::DiffEqFlux.var"#73#78"{Main.TestDiffeq3b.var"#3#5",Int64,Bool,Bool,typeof(Main.TestDiffeq3b.loss_neuralode),Array{Float32,1},Zygote.Params})() at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
 [41] with_logstate(::Function, ::Any) at ./logging.jl:408
 [42] with_logger at ./logging.jl:514 [inlined]
 [43] maybe_with_logger(::DiffEqFlux.var"#73#78"{Main.TestDiffeq3b.var"#3#5",Int64,Bool,Bool,typeof(Main.TestDiffeq3b.loss_neuralode),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqFlux.var"#69#71"}}}) at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
 [44] sciml_train(::Function, ::Array{Float32,1}, ::Flux.Optimise.ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/jboik/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
 [45] top-level scope at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:93
 [46] include(::String) at ./client.jl:457
 [47] top-level scope at REPL[2]:1
in expression starting at /home/jboik/Devel/Ideai_Ju/Ideai/examples/Load_01/load_fake/TestDiffEq3b.jl:93

@ChrisRackauckas
Member

Doesn't a GRU have state though, so the model wouldn't be well-defined?

@ChrisRackauckas
Member

I think @avik-pal and @DhairyaLGandhi have mentioned something about destructure giving different arguments out when a layer has state, which is a bit weird. Could one of you give some input on that? I think that would be something to fix up on the Flux side; even if it would be a breaking change, making the outputs of that function uniform and documented would fix issues like this.

@John-Boik
Author

If it's of help, learning-long-term-irregular-ts shows (starting on line 566) code for the GRUODE written in python.

@ChrisRackauckas
Member

Note that this method is only going to be compatible with adaptive=false, because otherwise the state makes the ODE ill-defined. I think all you need is to turn off adaptivity and use whatever that different destructure is.
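A minimal sketch of what that looks like (assuming the `prob` and `tsteps` from the example above; the dt here is an arbitrary illustrative choice):

```julia
# Fixed-step solve: with adaptive = false, the GRU's state updates happen at a
# deterministic sequence of time steps, so the model is at least well-defined.
res = Array(OrdinaryDiffEq.solve(prob, OrdinaryDiffEq.Tsit5(),
    saveat = tsteps, dt = 0.01, adaptive = false))
```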

@John-Boik
Author

John-Boik commented Oct 13, 2020

@avik-pal and @DhairyaLGandhi, note that the solve function runs properly and produces the anticipated output. The DimensionMismatch error occurs later, when gradients are taken. Also, the same error occurs when using adaptive=false, i.e. OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, adaptive=false, dt=.5), instead of the solve() in the code above.

@avik-pal
Member

The destructure issue Chris mentioned above should not lead to dimension mismatch error. It just makes the GRU work without any recurrence, as the state is overwritten every time we do re(p). (@DhairyaLGandhi do you know how to fix this?)

The exact source of the error you encounter seems to be the sensitivity algorithm. A quick fix would be:

res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=InterpolatingAdjoint(autojacvec=false)))

@DhairyaLGandhi
Member

DhairyaLGandhi commented Oct 13, 2020

We don't close over a number of arguments in destructure, which may be necessary for our restructure case as well. Adding those back to a cache, which can be passed around to the restructure, could do it.

function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)
    else
      cache[x] = x
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i.+(1:length(x))], size(x))
    i += length(x)
    return x
  end
end

This is untested currently, @avik-pal would something like this solve the specific issue you're talking about?

@John-Boik
Author

Thanks. I can verify that the following works with Flux.GRU. I used autojacvec=true rather than false, and it seems to run a bit faster that way.

module TestDiffeq3bb

using Revise
using Infiltrator
using Formatting

import DiffEqFlux
import OrdinaryDiffEq 
import DiffEqSensitivity
import Flux
import Optim
import Plots
import Zygote
import Functors

u0 = Float32[2.0; 0.0]    
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
iter = 0

function trueODEfunc(du, u, p, t)    
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)' 
end


prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps)) 

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 20, tanh),
    Flux.GRU(20, 20),
    Flux.Dense(20, 2),
    )


function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  Functors.fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)
    else
      cache[x] = x
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  Functors.fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i .+ (1:length(x))], size(x))
    i += length(x)
    return x
  end
end

p, re = destructure(dudt2)

function neural_ode_f(u, p, t)
    return re(p)(u)
end

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)


function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, sensealg=DiffEqSensitivity.InterpolatingAdjoint(autojacvec=true)))
    return res
end


function loss_neuralode(p)
        pred = predict_neuralode(p)
        loss = sum(abs2, ode_data .- pred) 
        return loss, pred
end


callback = function (p, l, pred; doplot = true)
    global iter
    iter += 1
    @show iter, l
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data", title=string(iter))
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode, 
    p,
    Flux.ADAM(0.05), 
    cb = callback,
    maxiters = 60
    )

end    # ------------------------------- module -----------------------------------

@ChrisRackauckas
Member

The fixed restructure/destructure works:

import DiffEqFlux
import OrdinaryDiffEq
import Flux
import Optim
import Plots
import Zygote

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

function trueODEfunc(du, u, p, t)
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end

prob_trueode = OrdinaryDiffEq.ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(OrdinaryDiffEq.solve(prob_trueode, OrdinaryDiffEq.Tsit5(), saveat = tsteps))

dudt2 = Flux.Chain(
    x -> x.^3,
    Flux.Dense(2, 50, tanh),
    #Flux.Dense(50, 2)
    Flux.GRU(50, 2)
    )


function destructure(m; cache = IdDict())
  xs = Zygote.Buffer([])
  Flux.fmap(m) do x
    if x isa AbstractArray
      push!(xs, x)
    else
      cache[x] = x
    end
    return x
  end
  return vcat(vec.(copy(xs))...), p -> _restructure(m, p, cache = cache)
end

function _restructure(m, xs; cache = IdDict())
  i = 0
  Flux.fmap(m) do x
    x isa AbstractArray || return cache[x]
    x = reshape(xs[i.+(1:length(x))], size(x))
    i += length(x)
    return x
  end
end

p, re = destructure(dudt2)
neural_ode_f(u, p, t) = re(p)(u)

prob = OrdinaryDiffEq.ODEProblem(neural_ode_f, u0, tspan, p)


function predict_neuralode(p)
    tmp_prob = OrdinaryDiffEq.remake(prob,p=p)
    res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(), saveat=tsteps, dt=0.01, adaptive=false))
    return res
end


function loss_neuralode(p)
        pred = predict_neuralode(p)  # (2,30)
        loss = sum(abs2, ode_data .- pred)  # scalar
        return loss, pred
end


callback = function (p, l, pred; doplot = true)
    display(l)
    # plot current prediction against data
    plt = Plots.scatter(tsteps, ode_data[1,:], label = "data")
    Plots.scatter!(plt, tsteps, pred[1,:], label = "prediction")
    if doplot
        display(Plots.plot(plt))
    end
    return false
end

result_neuralode = DiffEqFlux.sciml_train(
    loss_neuralode,
    p,
    Flux.ADAM(0.05),
    cb = callback,
    maxiters = 3000
    )


result_neuralode2 = DiffEqFlux.sciml_train(
    loss_neuralode,
    result_neuralode.minimizer,
    Flux.ADAM(0.05),
    cb = callback,
    maxiters = 1000,
    )

The method isn't very good, but it does what you asked for.

@John-Boik
Author

Excellent @ChrisRackauckas. I see that sensealg=DiffEqSensitivity.InterpolatingAdjoint() is not needed, and dt=0.01, adaptive=false can be used.

The method is a means to an end (eventually, GRU-ODE), but it does work reasonably well as is if Flux.GRU(50, 2) is replaced by Flux.GRU(50, 50), Flux.Dense(50, 2). At 200 iterations of just SGD, the error is about 0.18. If I switch in my custom GRU (below, as per the GRU-ODE paper), and reduce the hidden layer size to 20 from 50, the error is about 0.02 at 200 iterations. Both are less than the error of about 0.48 achieved with a chain of Flux.Dense(2, 50, tanh), Flux.Dense(50, 2).

The custom GRU is as follows. In this problem, I use Flux2.GRU2(20, true, x -> x ) when defining the chain.

module Flux2

import Flux
using Infiltrator

mutable struct GRUCell2{W,U,B,L,TF,F}
    update_W::W
    update_U::U
    update_b::B
    
    reset_W::W
    reset_U::U
    reset_b::B
    
    out_W::W
    out_U::U
    out_b::B
        
    H::L
    is_dhdt::TF
    fx::F
end


GRUCell2(L, is_output_dhdt, fx; init = Flux.glorot_uniform) =
    GRUCell2(
        init(L, L), 
        init(L, L),
        init(L,1), 
        
        init(L, L), 
        init(L, L),
        init(L,1), 
        
        init(L, L), 
        init(L, L),
        init(L,1), 
                
        zeros(Float32, (L,1)),
        is_output_dhdt,
        fx
        )


function (m::GRUCell2)(H, X)
    # H is the carried state passed in by Flux.Recur; use it (rather than
    # m.H, which holds only the initial state) so the recurrence is applied.
    update_gate = Flux.sigmoid.(
            (m.update_W * X)
            .+ (m.update_U * H)
            .+ m.update_b)

    reset_gate = Flux.sigmoid.(
            (m.reset_W * X)
            .+ (m.reset_U * H)
            .+ m.reset_b)

    output_gate = m.fx.(
            (m.out_W * X)
            .+ (m.out_U * (reset_gate .* H))
            .+ m.out_b)
    
    if m.is_dhdt == true
        # output is dhdt
        output =  (Float32(1) .- update_gate) .* (output_gate .- H)
    else
        # standard GRU output
        output = ((Float32(1) .- update_gate) .* output_gate) .+ (update_gate .* H)
    end
            
    H = output
    return H, H
end

Flux.hidden(m::GRUCell2) = m.H
Flux.@functor GRUCell2

Base.show(io::IO, l::GRUCell2) =
    print(io, "GRUCell2(", size(l.update_W, 2), ", ", size(l.update_W, 1), ")")

GRU2(a...; ka...) = Flux.Recur(GRUCell2(a...; ka...))

end  # -----------------------------  module

@ChrisRackauckas
Member

Cool, yeah. The other thing to try is sensealg=ReverseDiffAdjoint(). With a fixed time step, direct reverse-mode AD might be better, since it avoids the adjoint error that the continuous adjoints can accumulate, which is more of an issue when the reverse pass is not adaptive.
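A sketch of that, assuming the `tmp_prob` and `tsteps` from the code above:

```julia
import DiffEqSensitivity
# Differentiate directly through the fixed-step solve with reverse-mode AD,
# instead of solving a continuous adjoint ODE backwards in time.
res = Array(OrdinaryDiffEq.solve(tmp_prob, OrdinaryDiffEq.Tsit5(),
    saveat = tsteps, dt = 0.01, adaptive = false,
    sensealg = DiffEqSensitivity.ReverseDiffAdjoint()))
```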

@ChrisRackauckas
Member

It would be good to turn this into a tutorial when all is said and done. @DhairyaLGandhi could you add that restructure/destructure patch to Flux and then tag a release? @John-Boik would you be willing to contribute a tutorial?

@ChrisRackauckas
Member

Or @mkg33 might be able to help out here.

@John-Boik
Author

Sure, I would be happy to help if I can.

@mkg33
Contributor

mkg33 commented Oct 14, 2020

Of course, I'll add it to my tasks.

@sungjuGit

Has this "fix" been released to FluxML yet?

@sungjuGit

ODE-LSTM implementations
@John-Boik, are you or others still working on an ODE-LSTM implementation in FluxML?

@John-Boik
Author

I'm working on similar models, which also use re/de structure. The fix to re/de structure has been released, and both functions are working fine as far as I know.

@ChrisRackauckas
Member

I think @DhairyaLGandhi didn't merge the fix yet FluxML/Flux.jl#1353

It's still a bad model though.

@ChrisRackauckas
Member

This was fixed by FluxML/Flux.jl#1901, and one can now use Lux, which makes the state explicit. Cheers!
