Wrong gradient involving splatting of kwargs #1284
Comments
I wonder if this is to do with kwarg[:x] showing up both in kwarg... and kwarg[:x] ...? |
That's precisely the case, I believe.
|
Interestingly, the issue persists even after eliminating this redundancy:
|
This turned out to be a fun issue 😱 . In short, kwargs are represented as Given all that, I can't help but wonder if this has been causing other mysterious bugs in the wild. Working on a PR that should hopefully be up soon. |
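For readers less familiar with how Julia hands keyword arguments to a function, the `kwargs = Base.Pairs{...}` output printed later in this thread shows the shape: a pairs wrapper around a `NamedTuple`. Below is a minimal sketch in plain Julia (no Zygote involved). `capture` is a hypothetical helper name, and the sketch makes no claim about how Zygote treats this wrapper internally.

```julia
# Hypothetical helper (not from the issue): returns the slurped kwargs unchanged.
capture(; kwargs...) = kwargs

kw = capture(trajectories = 2, abstol = 1e-8, reltol = 1e-6)

# The slurped kwargs wrap a NamedTuple; `values` unwraps it.
nt = values(kw)
@assert nt isa NamedTuple
@assert keys(nt) == (:trajectories, :abstol, :reltol)

# Indexing by symbol forwards to the wrapped NamedTuple.
@assert kw[:trajectories] == 2
@assert kw[:trajectories] === nt.trajectories
```

So `kwargs[:trajectories]` and `kwargs...` are two routes into the same underlying `NamedTuple`, which is why a gradient rule for one has to agree with the rule for the other.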
This seems to cause downstream failures. See:
function multiple_shoot(
    p::AbstractArray,
    ode_data::AbstractArray,
    tsteps::AbstractArray,
    ensembleprob::EnsembleProblem,
    ensemblealg::SciMLBase.BasicEnsembleAlgorithm,
    loss_function,
    continuity_loss,
    solver::DiffEqBase.AbstractODEAlgorithm,
    group_size::Integer;
    continuity_term::Real=100,
    kwargs...
)
    datasize = size(ode_data, 2)
    prob = ensembleprob.prob
    if group_size < 2 || group_size > datasize
        throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
    end
    @assert ndims(ode_data) == 3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
    @assert size(ode_data,2) == length(tsteps)
    @show kwargs
    @assert size(ode_data,3) == kwargs[:trajectories]
This is then called like:
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end
kwargs = Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}}(:trajectories => 2, :abstol => 1.0e-8, :reltol => 1.0e-6)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
Stacktrace:
[1] (::Zygote.var"#kwargs_literal_getindex_pullback#326"{Zygote.var"#1925#back#218"{Zygote.var"#back#217"{:trajectories, Zygote.Context{false}, NamedTuple{(:trajectories, :abstol, :reltol), Tuple{Int64, Float64, Float64}}, Int64}}})(Δ::Nothing)
@ Zygote C:\Users\accou\.julia\packages\Zygote\qGFGD\src\lib\base.jl:165
[2] Pullback
@ c:\Users\accou\.julia\packages\DiffEqFlux\Em1Aj\src\multiple_shooting.jl:185 [inlined]
The line that errors is:
@assert size(ode_data,3) == kwargs[:trajectories]
My symbol is transformed into an integer and kwargs to https://github.com/SciML/SciMLSensitivity.jl/runs/8028243222?check_suite_focus=true
using DiffEqFlux, OrdinaryDiffEq, Test
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2], length=datasize)
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat=tsteps))
nn = FastChain((x, p) -> x .^ 3,
               FastDense(2, 16, tanh),
               FastDense(16, 2))
p_init = initial_params(nn)
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat=tsteps)
prob_node = ODEProblem((u, p, t) -> nn(u, p), u0, tspan, p_init)
function loss_function(data, pred)
    return sum(abs2, data - pred)
end
u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
function prob_func(prob, i, repeat)
    remake(prob, u0=u0s[i])
end
ensemble_prob = EnsembleProblem(prob_node, prob_func=prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode, prob_func=prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(ensemble_prob_trueODE, Tsit5(), ensemble_alg, trajectories=trajectories, saveat=tsteps))
group_size = 3
continuity_term = 200
function loss_multiple_shooting_ens(p)
    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
                          loss_function, Tsit5(),
                          group_size; continuity_term,
                          trajectories,
                          abstol=1e-8, reltol=1e-6) # test solver kwargs
end
res_ms_ensembles = DiffEqFlux.sciml_train(loss_multiple_shooting_ens, neuralode.p,
                                          ADAM(0.05), maxiters=300) |
Try #1295 on for size. |
That fixes it. |
Zygote (v0.6.43) currently gives a wrong gradient involving kwargs splatting. It seems to double-count a gradient contribution through implicit and explicit kwargs. Here's a small example:
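A minimal sketch of the pattern described, assuming hypothetical functions `inner` and `outer` (these are not the reporter's code): a value reaches the result both implicitly through the splat `kwargs...` and explicitly through `kwargs[:x]`. The sketch uses only Base Julia and a central finite difference to establish a reference derivative that any AD result should match. It does not call Zygote and does not claim what Zygote returns.

```julia
# Hypothetical functions, not taken from the issue.
# `inner` receives x through the splatted kwargs.
inner(; x, kwargs...) = 3x

# `outer` uses x twice: explicitly via kwargs[:x] and implicitly via kwargs...
outer(; kwargs...) = kwargs[:x] * inner(; kwargs...)

# outer(x = x) == x * 3x == 3x^2, so the exact derivative is 6x.
exact_grad(x) = 6x

# Central finite difference as an AD-independent reference.
function fd_grad(f, x; h = 1e-6)
    return (f(x + h) - f(x - h)) / (2h)
end

x0 = 2.0
g_fd = fd_grad(x -> outer(x = x), x0)
@assert isapprox(g_fd, exact_grad(x0); rtol = 1e-6)
```

Comparing an AD gradient of `x -> outer(x = x)` against `fd_grad` is one way to spot a contribution that has been counted more than once, since both routes into `x` must add up to exactly `6x`.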