Reproducible TypeError in sciml_train #278

Closed
metanoid opened this issue Jun 2, 2020 · 6 comments

@metanoid
Contributor

metanoid commented Jun 2, 2020

For a small jump diffusion problem adapted from here, training with sciml_train fails with a very odd TypeError in which the "expected" type is identical to the "got" type.

To reproduce:

using DiffEqFlux
using Flux
using DiffEqJump
using DiffEqSensitivity
using StochasticDiffEq
using DiffEqCallbacks

function f(du,u,p,t)
  du[1] = u[1]*p[1]
end

rate(u,p,t) = u[1]
affect!(integrator) = (integrator.u[1] = integrator.u[1]/2)
jump = VariableRateJump(rate,affect!)
jump2 = deepcopy(jump)


function g(du,u,p,t)
  du[1] = u[1]
end

sde_prob = SDEProblem(f,g,[0.2],(0.0,10.0), [0.5])

jump_prob = JumpProblem(sde_prob,Direct(),jump,jump2)

saved_values = SavedValues(Float64, Array{Float64})
cb = SavingCallback((u,t,integrator)->(u), saved_values, saveat=0.0:0.1:10.0)

sol = solve(jump_prob,SRIW1(), callback=cb )

θ = [0.2, 0.5]
function predict(θ)
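  u0 = θ[1:1]   # f and g are in-place (they write du[1]), so u0 must be a vector
  p = θ[2:end]
  prob = SDEProblem(f,g,u0,(0.0,10.0),p)
  # build the JumpProblem from the θ-dependent prob, not the global sde_prob
  jump_prob = JumpProblem(prob,Direct(),jump,jump2)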
  saved_values = SavedValues(Float64, Array{Float64})
  cb = SavingCallback((u,t,integrator)->(u), saved_values, saveat=0.0:0.1:10.0)
  sol = solve(jump_prob,SRIW1(), callback=cb )
  vals = [i[1] for i in saved_values.saveval]
  return sol, vals
end

test, test_vals = predict(θ);

function loss_test(θ)
  sol, vals = predict(θ)
  loss = 0.0
  for t in 1:length(vals)
      loss += abs2(vals[t] - 1.0)
  end
  return loss, vals
end

loss_test(θ)

result = DiffEqFlux.sciml_train(loss_test, θ,
                                    ADAM(0.1),
                                    maxiters = 5)

Error message:

ERROR: TypeError: in new, expected Zygote.var"#174#175"{typeof(∂(__init##kw)),Tuple{Tuple{Nothing,Nothing,Nothing},NTuple{5,Nothing}}}, got Zygote.var"#174#175"{typeof(∂(__init##kw)),Tuple{Tuple{Nothing,Nothing,Nothing},NTuple{5,Nothing}}}
Stacktrace:
 [1] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:49 [inlined]
 [2] #init_call#438 at C:\Users\username\.julia\packages\DiffEqBase\KnYSY\src\solve.jl:16 [inlined]
 [3] _pullback(::Zygote.Context, ::DiffEqBase.var"##init_call#438", ::Base.Iterators.Pairs{Symbol,DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}}, ::typeof(DiffEqBase.init_call), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [4] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [5] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [6] _pullback(::Zygote.Context, ::DiffEqBase.var"#init_call##kw", ::NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}, ::typeof(DiffEqBase.init_call), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [7] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [8] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [9] #init#439 at C:\Users\username\.julia\packages\DiffEqBase\KnYSY\src\solve.jl:33 [inlined]
 [10] _pullback(::Zygote.Context, ::DiffEqBase.var"##init#439", ::Base.Iterators.Pairs{Symbol,DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}}, ::typeof(init), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [11] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [12] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [13] _pullback(::Zygote.Context, ::DiffEqBase.var"#init##kw", ::NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}, ::typeof(init), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [14] #__solve#83 at C:\Users\username\.julia\packages\DiffEqJump\hNr94\src\solve.jl:6 [inlined]
 [15] _pullback(::Zygote.Context, ::DiffEqJump.var"##__solve#83", ::Base.Iterators.Pairs{Symbol,DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)},Tuple{Symbol},NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}}, ::typeof(DiffEqBase.__solve), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1, ::Array{Any,1}, ::Array{Any,1}, ::Array{Any,1}, ::Type{Val{true}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [16] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [17] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [18] #solve#447 at C:\Users\username\.julia\packages\DiffEqBase\KnYSY\src\solve.jl:108 [inlined]
 [19] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [20] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [21] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:callback,),Tuple{DiscreteCallback{DiffEqCallbacks.var"#30#31",DiffEqCallbacks.SavingAffect{var"#11#13",Float64,Array{Float64,N} where N,DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Array{Float64,1}},typeof(DiffEqCallbacks.saving_initialize)}}}, ::typeof(solve), ::JumpProblem{true,SDEProblem{ExtendedJumpArray{Array{Float64,1},Array{Float64,1}},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,DiffEqJump.var"#jump_f#63"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},JumpSet{Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Tuple{},Nothing,Nothing}},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},DiffEqJump.var"#60#64"{SDEProblem{Array{Float64,1},Tuple{Float64,Float64},true,Array{Float64,1},Nothing,SDEFunction{true,typeof(f),typeof(g),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(g),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Direct,CallbackSet{Tuple{ContinuousCallback{DiffEqJump.var"#75#77",DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#76#78"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64},ContinuousCallback{DiffEqJump.var"#79#81",DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},DiffEqJump.var"#80#82"{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},typeof(DiffEqBase.INITIALIZE_DEFAULT),Float64,Int64,Nothing,Int64}},Tuple{}},Nothing,Tuple{VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64},VariableRateJump{typeof(rate),typeof(affect!),Nothing,Float64,Int64}},Nothing,Nothing}, ::SRIW1) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [22] predict at C:\Users\username\Documents\Projects\CausalDiseasePrediction\Models\Experiments\zygote_error.jl:64 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(predict), ::Array{Float64,1}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [24] loss_test at C:\Users\username\Documents\Projects\CausalDiseasePrediction\Models\Experiments\zygote_error.jl:72 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(loss_test), ::Array{Float64,1}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [26] adjoint at C:\Users\username\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
 [27] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [28] #24 at C:\Users\username\.julia\packages\DiffEqFlux\aNQGp\src\train.jl:99 [inlined]
 [29] _pullback(::Zygote.Context, ::DiffEqFlux.var"#24#29"{Tuple{},typeof(loss_test),Array{Float64,1}}) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
 [30] pullback(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:174
 [31] gradient(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:54
 [32] macro expansion at C:\Users\username\.julia\packages\DiffEqFlux\aNQGp\src\train.jl:98 [inlined]
 [33] macro expansion at C:\Users\username\.julia\packages\ProgressLogging\g8xnW\src\ProgressLogging.jl:328 [inlined]
 [34] (::DiffEqFlux.var"#23#28"{DiffEqFlux.var"#27#32",Int64,Bool,Bool,typeof(loss_test),Array{Float64,1},Zygote.Params})() at C:\Users\username\.julia\packages\DiffEqFlux\aNQGp\src\train.jl:43
 [35] maybe_with_logger(::DiffEqFlux.var"#23#28"{DiffEqFlux.var"#27#32",Int64,Bool,Bool,typeof(loss_test),Array{Float64,1},Zygote.Params}, ::Nothing) at C:\Users\username\.julia\packages\DiffEqBase\KnYSY\src\utils.jl:259
 [36] sciml_train(::Function, ::Array{Float64,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\username\.julia\packages\DiffEqFlux\aNQGp\src\train.jl:42
 [37] top-level scope at none:0

Versions:

  [459566f4] DiffEqCallbacks v2.13.2
  [aae7a2af] DiffEqFlux v1.12.0 #master (https://github.com/SciML/DiffEqFlux.jl.git)
  [c894b116] DiffEqJump v6.7.6 #master (https://github.com/SciML/DiffEqJump.jl.git)
  [41bf760c] DiffEqSensitivity v6.19.0 #master (https://github.com/SciML/DiffEqSensitivity.jl.git)
  [587475ba] Flux v0.10.5 #master (https://github.com/FluxML/Flux.jl.git)

@ChrisRackauckas
Member

I think Zygote doesn't know how to handle the ExtendedJumpArray constructors.

@metanoid
Contributor Author

metanoid commented Jun 7, 2020

Are there currently any workarounds that would allow a jump diffusion to be trained via sciml_train?

@ChrisRackauckas
Member

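The quickest workaround is to use a derivative-free optimizer, for example using NLopt; Opt(:LN_BOBYQA, 4), where the second argument to Opt is the number of parameters being optimized (length(θ) for this problem).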
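A derivative-free optimizer sidesteps the error because it only ever calls the loss and never asks Zygote for a gradient. As a minimal sketch of that idea using only Julia's Base, here is a plain compass (pattern) search; it is a deliberately simpler stand-in, not NLopt's BOBYQA, and the names compass_search and toy_loss are hypothetical. Like loss_test above, the loss may return a tuple whose first element is the scalar loss.

```julia
# Hypothetical compass (pattern) search: a derivative-free optimizer that
# only evaluates the loss, so no gradient (and no Zygote pass) is needed.
# This is a simplified illustration, not NLopt's BOBYQA.
function compass_search(loss, θ0::Vector{Float64}; step = 0.5, tol = 1e-6, maxeval = 500)
    # Accept losses that return (loss, extras...) as loss_test does.
    value(θ) = (r = loss(θ); r isa Tuple ? first(r) : r)
    θ = copy(θ0)
    best = value(θ)
    nevals = 1
    while step > tol && nevals < maxeval
        improved = false
        # Try moving each coordinate by ±step; accept the first improvement.
        for i in eachindex(θ), s in (step, -step)
            trial = copy(θ)
            trial[i] += s
            v = value(trial)
            nevals += 1
            if v < best
                θ, best, improved = trial, v, true
                break
            end
            nevals >= maxeval && break
        end
        # No coordinate move helped: shrink the pattern.
        improved || (step /= 2)
    end
    return θ, best, nevals
end

# Toy loss with the same (loss, extras) return shape as loss_test.
toy_loss(θ) = (sum(abs2, θ .- [1.0, -2.0]), nothing)

θ_opt, loss_opt, nevals = compass_search(toy_loss, [0.2, 0.5])
println("θ = $θ_opt, loss = $loss_opt, evaluations = $nevals")
```

Because the search needs only loss values, the loss is called many times (up to maxeval), which is also a quick way to sanity-check that an optimizer wrapper is actually iterating.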

@metanoid
Contributor Author

metanoid commented Jun 7, 2020

I've updated the code to test this solution, but something is wrong: NLopt only ever evaluates the loss function once.

I only changed the last few lines of the code above:

line_counter = 0
function loss_test(θ)
  global line_counter = line_counter + 1
  println("Loss function evaluations: $(line_counter)")
  sol, vals = predict(θ)
  loss = 0.0
  for t in 1:length(vals)
      loss += abs2(vals[t] - 1.0)
  end
  return loss, vals
end

loss_test(θ)

using NLopt
opt = Opt(:LN_BOBYQA, length(θ))
result = DiffEqFlux.sciml_train(loss_test, θ,
                                    opt,
                                    maxeval = 500)

This shows that the loss function is called only once, rather than up to the 500 times allowed by maxeval = 500.
Is this a bug, or am I making a silly mistake?
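A global counter works, but in Julia a non-constant global is untyped and slow to access. A minimal alternative, assuming nothing beyond Base, is to wrap the loss in a closure that counts its own calls; counting and toy_loss below are hypothetical names used only for illustration.

```julia
# Hypothetical helper: wrap `loss` so each call increments a counter.
# Returns the wrapped function and a Ref holding the number of calls.
function counting(loss)
    ncalls = Ref(0)
    wrapped(θ) = (ncalls[] += 1; loss(θ))
    return wrapped, ncalls
end

# Toy loss with the same (loss, extras) return shape as loss_test.
toy_loss(θ) = (sum(abs2, θ .- 1.0), θ)

counted_loss, ncalls = counting(toy_loss)
for θ in ([0.0, 0.0], [0.5, 0.5], [1.0, 1.0])
    counted_loss(θ)
end
println("Loss function evaluations: $(ncalls[])")  # prints 3
```

With the code in this thread, this would mean passing the first return value of counting(loss_test) to the optimizer and reading the Ref afterwards.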

@ChrisRackauckas
Member

That could be a bug in the NLopt wrapper. I wouldn't worry about getting it fixed here, though; I'd focus on reproducing it and fixing it in GalacticOptim.jl.

@ChrisRackauckas
Member

sciml_train has been deprecated and removed, and this is fixed in Optimization.jl.
