PresetTimeCallback errors with Zygote differentiation #502
The issue is the mutation in PresetTimeCallback here: https://github.com/SciML/DiffEqCallbacks.jl/blob/master/src/preset_time.jl#L21-L23

Will this be fixed? Should I open a separate issue on DiffEqCallbacks.jl? I am not sure how I would tackle this myself.

It'll take time.
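For reference, the initialize_preset code quoted in the next comment filters the preset times (when filter_tstops is true) so that only times strictly inside tspan are registered as tstops. Multiplying by tdir makes a single pair of strict comparisons work for both forward and backward integration. Below is a minimal Base-only sketch of that idea. `filter_interior` is a hypothetical helper, not DiffEqCallbacks API. Boolean indexing builds a new array, so `tstops` itself is never mutated:

```julia
# Hypothetical helper (not part of DiffEqCallbacks): keep only the times
# strictly inside tspan, for forward or backward integration.
function filter_interior(tstops::AbstractVector, tspan::Tuple)
    t0, t1 = tspan
    tdir = sign(t1 - t0)   # +1 forward, -1 backward
    # Scaling by tdir turns a backward span into a forward one, so the same
    # two strict comparisons work in both directions.
    keep = @. (tdir * tstops > tdir * t0) & (tdir * tstops < tdir * t1)
    return tstops[keep]    # boolean indexing allocates a new vector
end

dosetimes = [1.0, 2.0, 4.0, 8.0]
filter_interior(dosetimes, (0.0, 10.5))   # all four times are kept
filter_interior(dosetimes, (10.5, 3.0))   # backward span: keeps 4.0 and 8.0
```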
Wrapping the add_tstop! calls in Zygote.ignore() unfortunately only shifts the problem:

initialize_preset = function (c, u, t, integrator)
    initialize(c, u, t, integrator)
    Zygote.ignore() do
        if filter_tstops
            tdir = integrator.tdir
            _tstops = tstops[@.((tdir*tstops > tdir*integrator.sol.prob.tspan[1]) * (tdir*tstops < tdir*integrator.sol.prob.tspan[2]))]
            add_tstop!.((integrator,), _tstops)
        else
            add_tstop!.((integrator,), tstops)
        end
    end
    if t ∈ tstops
        user_affect!(integrator)
    end
end

It looks like it's not specific to PresetTimeCallback: a plain DiscreteCallback built inside the loss function fails the same way (loss_n_ode5 below).

using DiffEqSensitivity, OrdinaryDiffEq, DiffEqCallbacks, DiffEqFlux, Flux, Zygote
using Random, Test
u0 = Float32[2.; 0.]
datasize = 100
tspan = (0.0f0,10.5f0)
dosetimes = [1.0,2.0,4.0,8.0]
function affect!(integrator)
integrator.u = integrator.u.+1
end
cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
function trueODEfunc(du,u,p,t)
du .= -u
end
t = range(tspan[1],tspan[2],length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t))
dudt2 = Chain(Dense(2,50,tanh),
Dense(50,2))
p,re = Flux.destructure(dudt2) # use this p as the initial condition!
function dudt(du,u,p,t)
du[1:2] .= -u[1:2]
du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end])
end
z0 = Float32[u0;u0]
prob = ODEProblem(dudt,z0,tspan)
affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end]
cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
function loss_n_ode2(p)
_prob = remake(prob,p=p)
#cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
function loss_n_ode3(p)
_prob = remake(prob,p=p)
cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
cbdisc = DiscreteCallback((u, t, integrator)->(t ∈ dosetimes), affect!,save_positions=(false,false))
function loss_n_ode4(p)
_prob = remake(prob,p=p)
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cbdisc,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
function loss_n_ode5(p)
_prob = remake(prob,p=p)
cbdisc = DiscreteCallback((u, t, integrator)->(t ∈ dosetimes), affect!,save_positions=(false,false))
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cbdisc,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
# all cbs lead to the same loss value
@test loss_n_ode2(p) == loss_n_ode3(p)
@test loss_n_ode2(p) == loss_n_ode4(p)
@test loss_n_ode2(p) == loss_n_ode5(p)
Zygote.gradient(loss_n_ode2,p) # works
Zygote.gradient(loss_n_ode3,p) # mutating error
Zygote.gradient(loss_n_ode4,p) # works
Zygote.gradient(loss_n_ode5,p) # mutating error

Stacktrace:
How about defining a remake function for callbacks?

# remake callback, to update callbacks inside Zygote.gradient(..)
function myremake(cb::DiscreteCallback; condition=cb.condition, (affect!)=cb.affect!, initialize=cb.initialize, finalize=cb.finalize, save_positions=cb.save_positions)
DiscreteCallback(condition,affect!,initialize,finalize,save_positions)
end
Zygote.@nograd myremake #avoid taking grads of this function
myremake(cb)
myremake(cbdisc)
function loss_n_ode6(p)
_prob = remake(prob,p=p)
_cb = myremake(cb)
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
function loss_n_ode7(p)
_prob = remake(prob,p=p)
_cb = myremake(cbdisc)
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
@test loss_n_ode2(p) == loss_n_ode6(p)
@test loss_n_ode2(p) == loss_n_ode7(p)
@test Zygote.gradient(loss_n_ode6,p) == Zygote.gradient(loss_n_ode2,p) # works
@test Zygote.gradient(loss_n_ode7,p) == Zygote.gradient(loss_n_ode2,p) # works
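The myremake function above is a copy-with-overrides constructor. Each keyword defaults to the existing field, so myremake(cb) rebuilds an equivalent callback and myremake(cbdisc, condition=...) swaps a single field. Here is a Base-only sketch of the same pattern on a hypothetical `ToyCallback` struct with a hypothetical `toy_remake`. It is not DiffEqBase.DiscreteCallback and involves no Zygote; it only shows the mechanics:

```julia
# Hypothetical stand-in for a callback object (NOT DiffEqBase.DiscreteCallback).
struct ToyCallback{C,A}
    condition::C
    affect::A
    save_positions::Tuple{Bool,Bool}
end

# Copy-with-overrides: each keyword defaults to the current field value,
# so toy_remake(cb) rebuilds an equivalent object and callers only name
# the fields they want to change. The original object is never modified.
function toy_remake(cb::ToyCallback;
                    condition = cb.condition,
                    affect = cb.affect,
                    save_positions = cb.save_positions)
    return ToyCallback(condition, affect, save_positions)
end

cb  = ToyCallback((u, t, integrator) -> t in [1.0, 2.0, 4.0, 8.0],
                  integrator -> nothing, (false, false))
cb2 = toy_remake(cb; condition = (u, t, integrator) -> t in [1.0, 3.0, 6.0, 8.0])
cb2.condition(nothing, 3.0, nothing)   # true
cb.condition(nothing, 3.0, nothing)    # false: the original is unchanged
```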
OK, unfortunately that doesn't work if one actually changes something:

dosetimes2 = [1.0,3.0,6.0,8.0]
cb2 = PresetTimeCallback(dosetimes2,affect!,save_positions=(false,false))
function loss_n_ode8(p)
_prob = remake(prob,p=p)
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb2,saveat=t,tstops=dosetimes2,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
function loss_n_ode9(p)
_prob = remake(prob,p=p)
_condition = (u, t, integrator)->(t ∈ dosetimes2)
_cb = myremake(cbdisc, condition=_condition)
pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes2,sensealg=ReverseDiffAdjoint()))[1:2,:]
sum(abs2,ode_data .- pred)
end
@test loss_n_ode2(p) != loss_n_ode8(p)
@test loss_n_ode8(p) == loss_n_ode9(p)
@test Zygote.gradient(loss_n_ode8,p) == Zygote.gradient(loss_n_ode9,p) # Zygote.gradient(loss_n_ode9,p) fails with mutating error
@test Zygote.gradient(loss_n_ode8,p) != Zygote.gradient(loss_n_ode2,p)

Relevant ChainRules/ChainRulesCore issues and PRs:

In the following MWE, only the gradient for myf4b fails:

using Zygote, Test
struct MyStruct{T1}
sp::T1
end
c = MyStruct(true)
function myf(p)
if c.sp
(p[1]^2 + p[2]^2)/2
else
p[1] + p[2]
end
end
p = randn(2)
@test myf(p) == sum(abs2,p)/2
@test Zygote.gradient(myf,p) == (p,)
function myf2(p)
_c = MyStruct(true)
if _c.sp
(p[1]^2 + p[2]^2)/2
else
p[1] + p[2]
end
end
@test myf2(p) == sum(abs2,p)/2
@test Zygote.gradient(myf2,p) == (p,)
function myf3(p)
_c = MyStruct([true])
if _c.sp[1]
(p[1]^2 + p[2]^2)/2
else
p[1] + p[2]
end
end
@test myf3(p) == sum(abs2,p)/2
@test Zygote.gradient(myf3,p) == (p,)
c2 = MyStruct(BitArray([1]))
function myf4a(p)
if c2.sp[1]
(p[1]^2 + p[2]^2)/2
else
p[1] + p[2]
end
end
@test myf4a(p) == sum(abs2,p)/2
@test Zygote.gradient(myf4a,p) == (p,)
function myf4b(p)
_c = MyStruct(BitArray([1]))
if _c.sp[1]
(p[1]^2 + p[2]^2)/2
else
p[1] + p[2]
end
end
@test myf4b(p) == sum(abs2,p)/2
@test Zygote.gradient(myf4b,p) == (p,) # fails

Maybe @oxinabox has an idea?
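The only difference between myf3 (which passes) and myf4b (which fails) is the container built inside the differentiated function. [true] is a Vector{Bool}, while BitArray([1]) is a bit-packed BitVector. myf4a builds the same BitArray, but outside the function. The following Base-only check shows those types, plus Bool[1], which yields the same value as a plain Vector{Bool}. As myf3 suggests, Bool[1] might sidestep the problem, but that is an assumption and was not tested against Zygote here:

```julia
# What myf3 builds: a plain Vector{Bool}.
v = [true]
# What myf4a/myf4b build: a BitVector (bits packed into chunks); the Int 1 is converted to true.
b = BitArray([1])
# Also converts 1 to true, but yields a plain Vector{Bool} like myf3 (hypothetical workaround).
w = Bool[1]

v isa Vector{Bool}   # true
b isa BitVector      # true
w isa Vector{Bool}   # true
v[1] == b[1] == w[1] # true: same value, different container types
```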
I think we just need to add to ChainRules.jl:

If I do that, all your code in the last comment succeeds. PRs welcome.

Fixed.
Trying to solve a problem similar to
https://diffeqflux.sciml.ai/dev/examples/hybrid_diffeq/
seems to create some problems. The linked example runs as is, but creating the callback inside predict_n_ode results in an error, e.g.:

I need my callback to depend on the training data, so it would be good to create the affect! function on the fly. It's probably possible to get around this using global variables, but I would prefer not to.
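In the MWEs earlier in this thread, building the callback outside the loss function (loss_n_ode2, loss_n_ode4) differentiates fine, while building it inside (loss_n_ode3, loss_n_ode5) hits the mutation error. One hedged way to get data-dependent callbacks without globals is therefore a closure factory: build the condition and affect! from the training data once, before calling Zygote.gradient, and pass the resulting callback in. The sketch below only shows the closure mechanics in plain Julia. `make_condition`, `make_affect` and `MockIntegrator` are hypothetical, and `MockIntegrator` stands in for the real integrator with just the `u` and `t` fields:

```julia
# Mock integrator with only the fields this sketch touches (hypothetical,
# NOT the OrdinaryDiffEq integrator type).
mutable struct MockIntegrator
    u::Vector{Float64}
    t::Float64
end

# Hypothetical factory: a condition that fires at the given preset times.
make_condition(times) = (u, t, integrator) -> t in times

# Hypothetical factory: an affect that resets the state to column k of `data`
# when the integrator sits at times[k]. `data` and `times` are captured by the
# closure, so no global variables are involved.
function make_affect(data::AbstractMatrix, times)
    return function (integrator)
        k = findfirst(==(integrator.t), times)
        if k !== nothing
            integrator.u .= view(data, :, k)   # in-place reset, like affect! above
        end
        return nothing
    end
end

# Usage: build both closures once from the data, outside any loss function.
data  = [1.0 2.0; 3.0 4.0]   # column k is the state to restore at times[k]
times = [1.0, 2.0]
cond  = make_condition(times)
aff   = make_affect(data, times)
integ = MockIntegrator([0.0, 0.0], 2.0)
cond(integ.u, integ.t, integ)   # true: t = 2.0 is a preset time
aff(integ)                      # integ.u becomes [2.0, 4.0]
```

With DiffEqCallbacks, the closures would then be wrapped once, e.g. DiscreteCallback(make_condition(dosetimes), make_affect(data, dosetimes)), outside the loss function. Whether that avoids the error in every setup is untested here.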