PresetTimeCallback errors with Zygote differentiation #502

Closed
mfalt opened this issue Mar 8, 2021 · 10 comments

@mfalt
mfalt commented Mar 8, 2021

Trying to solve a problem similar to
https://diffeqflux.sciml.ai/dev/examples/hybrid_diffeq/
seems to cause problems. The example above runs as is, but creating the callback inside predict_n_ode results in an error, e.g.:

function predict_n_ode()
    _prob = remake(prob,p=p)
    cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
    Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:]
end

I need my callback to depend on the training data, so it would be good to create the affect! function on the fly. It's probably possible to work around this using global variables, but I would prefer not to.
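
As a rough sketch of what "on the fly" could look like (not the code from this issue), a closure can capture the data instead of reading a global. `FakeIntegrator` and `make_affect` are hypothetical names used only for illustration; a real integrator comes from the solver, and this pattern alone does not avoid the error below, which is raised while the callback is constructed.

```julia
# Hypothetical stand-in for a solver integrator, used only to make the sketch runnable.
mutable struct FakeIntegrator
    u::Vector{Float64}
end

# Build an affect! that captures `dose` in a closure rather than reading a global.
make_affect(dose) = integrator -> (integrator.u = integrator.u .+ dose; nothing)

affect! = make_affect(0.5)
integ = FakeIntegrator([2.0, 0.0])
affect!(integ)   # adds the captured dose to every state component
```

The returned closure has the `affect!(integrator)` signature that PresetTimeCallback and DiscreteCallback expect.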

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#372#373")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/lib/array.jl:58
  [3] (::Zygote.var"#2249#back#374"{Zygote.var"#372#373"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./bitarray.jl:344 [inlined]
  [5] (::typeof(∂(copy_to_bitarray_chunks!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./bitarray.jl:534 [inlined]
  [7] (::typeof(∂(BitVector)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./bitarray.jl:503 [inlined]
  [9] (::typeof(∂(BitArray)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/DiffEqBase/krS36/src/callbacks.jl:282 [inlined]
 [11] (::typeof(∂(DiscreteCallback)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/DiffEqBase/krS36/src/callbacks.jl:288 [inlined]
 [13] (::typeof(∂(#DiscreteCallback#21)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DiffEqBase/krS36/src/callbacks.jl:288 [inlined]
 [15] (::typeof(∂(Type##kw)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/packages/DiffEqCallbacks/3xzvj/src/preset_time.jl:29 [inlined]
 [17] (::typeof(∂(#PresetTimeCallback#60)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/.julia/packages/DiffEqCallbacks/3xzvj/src/preset_time.jl:5 [inlined]
 [19] (::typeof(∂(PresetTimeCallback##kw)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [20] Pullback
    @ /work/mattiasf/nn-quad/test.jl:34 [inlined]
 [21] (::typeof(∂(predict_n_ode)))(Δ::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [22] Pullback
    @ /work/mattiasf/nn-quad/test.jl:40 [inlined]
 [23] (::typeof(∂(loss_n_ode)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface2.jl:0
 [24] #151
    @ ~/.julia/packages/Zygote/KpME9/src/lib/lib.jl:191 [inlined]
 [25] #1682#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [26] Pullback
    @ ~/.julia/packages/Flux/goUGu/src/optimise/train.jl:103 [inlined]
 [27] (::Zygote.var"#54#55"{Zygote.Params, typeof(∂(#15)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:172
 [28] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/KpME9/src/compiler/interface.jl:49
 [29] macro expansion
    @ ~/.julia/packages/Flux/goUGu/src/optimise/train.jl:102 [inlined]
 [30] macro expansion
    @ ~/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [31] train!(loss::Function, ps::Zygote.Params, data::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, opt::ADAM; cb::var"#310#312")
    @ Flux.Optimise ~/.julia/packages/Flux/goUGu/src/optimise/train.jl:100
 [32] top-level scope
@ChrisRackauckas ChrisRackauckas changed the title Creating discrete callback inside train! PresetTimeCallback errors with Zygote differentiation Mar 13, 2021
@ChrisRackauckas
Member

The issue is the mutation in PresetTimeCallback here:

https://github.com/SciML/DiffEqCallbacks.jl/blob/master/src/preset_time.jl#L21-L23

@mfalt
Author

mfalt commented Apr 19, 2021

Will this be fixed? Should I open a separate issue on DiffEqCallbacks.jl? I am not sure how I would tackle this myself.

@ChrisRackauckas
Member

It'll take time.

@frankschae
Member

Adding a Zygote.ignore

initialize_preset = function (c, u, t, integrator)
    initialize(c, u, t, integrator)
    Zygote.ignore() do
        if filter_tstops
            tdir = integrator.tdir
            _tstops = tstops[@.((tdir*tstops > tdir*integrator.sol.prob.tspan[1]) * (tdir*tstops < tdir*integrator.sol.prob.tspan[2]))]
            add_tstop!.((integrator,), _tstops)
        else
            add_tstop!.((integrator,), tstops)
        end
    end
    if t ∈ tstops
        user_affect!(integrator)
    end
end

unfortunately only shifts the problem. It looks like it's not specific to PresetTimeCallback but also fails for DiscreteCallback; see loss_n_ode5(p) below. The stack trace points to the save_positions BitArray.
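
For reference, a minimal sketch of what Zygote.ignore does in isolation, assuming only that Zygote is installed. The function `f` and the `flags` array are made up for illustration and are unrelated to the callback code: the block is run but not traced, so mutating a BitArray inside it does not raise the error.

```julia
using Zygote

function f(x)
    # Zygote runs this block but does not differentiate through it,
    # so the in-place write to the BitVector is allowed here.
    flags = Zygote.ignore() do
        b = falses(2)
        b[1] = true
        b
    end
    flags[1] ? x^2 : x
end

g = Zygote.gradient(f, 3.0)[1]   # derivative of x^2 at x = 3
```

The same construction outside the ignore block, but still inside the differentiated function, is what the stack trace below complains about.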

using DiffEqSensitivity, OrdinaryDiffEq, DiffEqCallbacks, DiffEqFlux, Flux
using Random, Test, Zygote
u0 = Float32[2.; 0.]
datasize = 100
tspan = (0.0f0,10.5f0)
dosetimes = [1.0,2.0,4.0,8.0]

function affect!(integrator)
    integrator.u = integrator.u.+1
end
cb_ = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
function trueODEfunc(du,u,p,t)
    du .= -u
end
t = range(tspan[1],tspan[2],length=datasize)

prob = ODEProblem(trueODEfunc,u0,tspan)
ode_data = Array(solve(prob,Tsit5(),callback=cb_,saveat=t))
dudt2 = Chain(Dense(2,50,tanh),
             Dense(50,2))
p,re = Flux.destructure(dudt2) # use this p as the initial condition!

function dudt(du,u,p,t)
    du[1:2] .= -u[1:2]
    du[3:end] .= re(p)(u[1:2]) #re(p)(u[3:end])
end
z0 = Float32[u0;u0]
prob = ODEProblem(dudt,z0,tspan)

affect!(integrator) = integrator.u[1:2] .= integrator.u[3:end]

cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))

function loss_n_ode2(p)
    _prob = remake(prob,p=p)
    #cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

function loss_n_ode3(p)
    _prob = remake(prob,p=p)
    cb = PresetTimeCallback(dosetimes,affect!,save_positions=(false,false))
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb,saveat=t,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

cbdisc = DiscreteCallback((u, t, integrator)->(t ∈ dosetimes), affect!,save_positions=(false,false))

function loss_n_ode4(p)
    _prob = remake(prob,p=p)
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cbdisc,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

function loss_n_ode5(p)
    _prob = remake(prob,p=p)
    cbdisc = DiscreteCallback((u, t, integrator)->(t ∈ dosetimes), affect!,save_positions=(false,false))
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cbdisc,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

# all cbs lead to the same loss value
@test loss_n_ode2(p) == loss_n_ode3(p)
@test loss_n_ode2(p) == loss_n_ode4(p)
@test loss_n_ode2(p) == loss_n_ode5(p)
Zygote.gradient(loss_n_ode2,p) # works
Zygote.gradient(loss_n_ode3,p) # mutating error 
Zygote.gradient(loss_n_ode4,p) # works
Zygote.gradient(loss_n_ode5,p) # mutating error

stacktrace:

ERROR: LoadError: Mutating arrays is not supported -- called setindex!(::Vector{UInt64}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#438#439"{Vector{UInt64}})(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/lib/array.jl:76
  [3] (::Zygote.var"#2373#back#440"{Zygote.var"#438#439"{Vector{UInt64}}})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ ./bitarray.jl:344 [inlined]
  [5] (::typeof(∂(copy_to_bitarray_chunks!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./bitarray.jl:534 [inlined]
  [7] (::typeof(∂(BitVector)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./bitarray.jl:503 [inlined]
  [9] (::typeof(∂(BitArray)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/DiffEqBase/gCXSd/src/callbacks.jl:289 [inlined]
 [11] (::typeof(∂(DiscreteCallback)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [12] Pullback
    @ ~/.julia/packages/DiffEqBase/gCXSd/src/callbacks.jl:295 [inlined]
 [13] (::typeof(∂(#DiscreteCallback#23)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [14] Pullback
    @ ~/.julia/packages/DiffEqBase/gCXSd/src/callbacks.jl:295 [inlined]
 [15] (::typeof(∂(Type##kw)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [16] Pullback
    @ ~/.julia/dev/DiffEqSensitivity/test/downstream/HybridNODE.jl:113 [inlined]
 [17] (::typeof(∂(loss_n_ode5)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface2.jl:0
 [18] (::Zygote.var"#46#47"{typeof(∂(loss_n_ode5))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:41
 [19] gradient(f::Function, args::Vector{Float32})
    @ Zygote ~/.julia/packages/Zygote/TaBlo/src/compiler/interface.jl:76
 [20] top-level scope
    @ ~/.julia/dev/DiffEqSensitivity/test/downstream/HybridNODE.jl:126
in expression starting at ..

@frankschae
Member

How about defining a remake function and adding a nograd for it to avoid differentiation of the constructor? At least it looks like we can use it to work around the Zygote error.

# remake callback, to update callbacks inside Zygote.gradient(..)
function myremake(cb::DiscreteCallback; condition=cb.condition, (affect!)=cb.affect!, initialize=cb.initialize, finalize=cb.finalize, save_positions=cb.save_positions)
    DiscreteCallback(condition,affect!,initialize,finalize,save_positions)
end

Zygote.@nograd myremake #avoid taking grads of this function

myremake(cb)
myremake(cbdisc)


function loss_n_ode6(p)
    _prob = remake(prob,p=p)
    _cb = myremake(cb)
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

function loss_n_ode7(p)
    _prob = remake(prob,p=p)
    _cb = myremake(cbdisc)
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

@test loss_n_ode2(p) == loss_n_ode6(p)
@test loss_n_ode2(p) == loss_n_ode7(p)

@test Zygote.gradient(loss_n_ode6,p) == Zygote.gradient(loss_n_ode2,p) # works
@test Zygote.gradient(loss_n_ode7,p) == Zygote.gradient(loss_n_ode2,p) # works 

@frankschae
Member

Ok.. unfortunately that doesn't work if one actually changes something:

dosetimes2 = [1.0,3.0,6.0,8.0]
cb2 = PresetTimeCallback(dosetimes2,affect!,save_positions=(false,false))

function loss_n_ode8(p)
    _prob = remake(prob,p=p)
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=cb2,saveat=t,tstops=dosetimes2,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

function loss_n_ode9(p)
    _prob = remake(prob,p=p)
    _condition = (u, t, integrator)->(t ∈ dosetimes2)
    _cb = myremake(cbdisc, condition=_condition)
    pred = Array(solve(_prob,Tsit5(),u0=z0,p=p,callback=_cb,saveat=t,tstops=dosetimes2,sensealg=ReverseDiffAdjoint()))[1:2,:]
    sum(abs2,ode_data .- pred)
end

@test loss_n_ode2(p) != loss_n_ode8(p)
@test loss_n_ode8(p) == loss_n_ode9(p)

@test Zygote.gradient(loss_n_ode8,p) == Zygote.gradient(loss_n_ode9,p) #  Zygote.gradient(loss_n_ode9,p)  fails with mutating error
@test Zygote.gradient(loss_n_ode8,p) != Zygote.gradient(loss_n_ode2,p)

@frankschae
Member

In the following MWE, only the gradient for myf4b fails (with the same stacktrace):

using Zygote, Test

struct MyStruct{T1}
  sp::T1
end 

c = MyStruct(true)

function myf(p)
  if c.sp
    (p[1]^2 + p[2]^2)/2
  else
    p[1] + p[2]
  end 
end

p = randn(2)
@test myf(p) == sum(abs2,p)/2
@test Zygote.gradient(myf,p) == (p,)

function myf2(p)
  _c = MyStruct(true)
  if _c.sp
    (p[1]^2 + p[2]^2)/2
  else
    p[1] + p[2]
  end 
end

@test myf2(p) == sum(abs2,p)/2
@test Zygote.gradient(myf2,p) == (p,)

function myf3(p)
  _c = MyStruct([true])
  if _c.sp[1]
    (p[1]^2 + p[2]^2)/2
  else
    p[1] + p[2]
  end 
end

@test myf3(p) == sum(abs2,p)/2
@test Zygote.gradient(myf3,p) == (p,)


c2 = MyStruct(BitArray([1]))

function myf4a(p)
  if c2.sp[1]
    (p[1]^2 + p[2]^2)/2
  else
    p[1] + p[2]
  end 
end

@test myf4a(p) == sum(abs2,p)/2
@test Zygote.gradient(myf4a,p) == (p,)


function myf4b(p)
  _c = MyStruct(BitArray([1]))
  if _c.sp[1]
    (p[1]^2 + p[2]^2)/2
  else
    p[1] + p[2]
  end 
end

@test myf4b(p) == sum(abs2,p)/2
@test Zygote.gradient(myf4b,p) == (p,)  # fails

Maybe @oxinabox has an idea?

@oxinabox
oxinabox commented Aug 19, 2021

I think we just need to add to ChainRules.jl:
@non_differentiable BitArray(::Any)
Which will tell it not to AD through the construction of a BitArray, as it has no tangent space anyway.

If I do that, all your code in the last comment succeeds.

PRs welcome
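
Until such a rule is available, the same mechanism can probably be applied to a helper you own, which avoids defining a rule for a Base type from user code. This is a sketch under assumptions: `Flags` and `make_flags` are hypothetical names, and both ChainRulesCore and Zygote need to be installed.

```julia
using ChainRulesCore, Zygote

struct Flags{T}
    sp::T
end

# Hypothetical helper that owns the BitArray construction.
make_flags(n::Integer) = Flags(trues(n))

# Mark the helper as non-differentiable; AD systems that use ChainRules
# rules (Zygote included) then treat its result as a constant.
@non_differentiable make_flags(::Integer)

function f(p)
    c = make_flags(1)   # built inside the differentiated function
    c.sp[1] ? (p[1]^2 + p[2]^2) / 2 : p[1] + p[2]
end

p = [1.0, 2.0]
g = Zygote.gradient(f, p)[1]
```

This mirrors myf4b from the previous comment, with the BitArray construction moved behind the marked helper.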

@ChrisRackauckas
Copy link
Member

Fixed.
