Gradient over implicit parameters returns nothing #692
If I remove the custom adjoint, it works fine, so it has to be something related to the interaction between params and a custom struct:

using Flux

struct S
    W::Array{Float64}
end
Flux.@functor S

s = S(randn(4,4))
ps = params(s)
foo(s::S) = sum(s.W)
gs = gradient(ps) do
    foo(s)
end
gs[s.W] # correct gradient
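A workaround is to write an intermediary function that takes only array inputs:

using Flux
using Zygote: @adjoint

struct S
    W::Array{Float64}
end
Flux.@functor S

s = S(randn(4,4))
ps = params(s)

# The outer function unpacks the struct, so Zygote itself differentiates the access to s.W
fff(s::S) = _fff(s.W)
# The custom adjoint is attached only to the array-level function
_fff(w) = sum(sin.(w))
@adjoint function _fff(w)
    _fff(w), Δ -> (similar(w) .= Δ .* cos.(w),)
end

gs = gradient(ps) do
    fff(s)
end
gs[s.W] # correct gradients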
@MikeInnes Any idea what is happening here? This issue silently produces wrong gradients, and it took me a while just to figure out that the bug originated in Zygote.
I don't believe Zygote can track implicit-params usage inside adjoint functions (by design, I assume; otherwise there'd be no way to avoid AD in custom adjoints). So if
To clarify a bit more, your custom adjoint means that the computation graph that the AD system works with looks like:
In other words, it never "sees" the array
Here, Zygote returns your custom adjoint for the gradient w.r.t. Similarly, without the custom adjoint, the AD has
Here,
@darsnack Do you consider this resolved?
Resolved == can't fix? (Anyone feel free to reopen in case I'm wrong.)

Yes, my understanding is that this is by design for adjoints. Writing a custom rule forces the AD to look away, and I don't think we would merge a change that breaks that fundamental assumption. The only alternative fix I see on Zygote's end would be to do post-pullback accumulation into implicit params, since Zygote computes the structural gradient anyway. This would require recursively traversing all the values. Maybe @mcabbott can comment on the correctness/feasibility of this. The recommended fixes here are to:
Structural gradient for reference:

gradient(fun, S(rand(2, 2)))
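As a rough illustration of the post-pullback accumulation idea mentioned above, here is a minimal plain-Julia sketch (no Zygote needed). It assumes the implicit parameters live in an identity-keyed IdDict pre-populated with nothing, and that a custom rule returns a structural gradient made of nested NamedTuples/Tuples mirroring the object's fields. The helper accumulate_structural! and the Outer struct are hypothetical illustrations, not Zygote or Flux APIs, and the sketch ignores cases a real implementation would need (arrays of structs, mutable-struct tangents, ChainRules Tangent types).

```julia
# Hypothetical helper (NOT a Zygote/Flux API): walk an object `x` and its
# structural gradient `dx` in lockstep, adding each array gradient into the
# identity-keyed table `grads` whenever that array is a tracked parameter.
function accumulate_structural!(grads::IdDict{Any,Any}, x, dx)
    dx === nothing && return grads                 # no gradient flowed here
    if haskey(grads, x)                            # `x` itself is a tracked parameter
        prev = grads[x]
        grads[x] = prev === nothing ? dx : prev .+ dx
    elseif dx isa NamedTuple                       # gradient mirrors the fields of `x`
        for name in keys(dx)
            accumulate_structural!(grads, getfield(x, name), getfield(dx, name))
        end
    elseif dx isa Tuple                            # gradient mirrors a tuple `x`
        for (xi, dxi) in zip(x, dx)
            accumulate_structural!(grads, xi, dxi)
        end
    end
    return grads
end

# Example model: a struct nested inside another struct
struct S
    W::Matrix{Float64}
end

struct Outer
    inner::S
    b::Vector{Float64}
end

m = Outer(S(zeros(2, 2)), zeros(3))

# Tracked parameters, registered up front with `nothing`
grads = IdDict{Any,Any}()
grads[m.inner.W] = nothing
grads[m.b] = nothing

# A structural gradient of the shape a custom rule might return for `m`
dm = (inner = (W = ones(2, 2),), b = fill(2.0, 3))

accumulate_structural!(grads, m, dm)
accumulate_structural!(grads, m, dm)   # a second contribution is summed in
```

Keying by object identity is what lets the array m.inner.W itself act as the lookup key, which is also why the thread's gs[s.W] lookups work in the implicit API.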
OK, thanks! For future reference, here is also this example from @ToucheSir (posted on Slack):

using Zygote
using Zygote: @adjoint

struct S
    W::Array{Float64}
end
s = S(randn(4,4))
fun(s::S) = sum(s.W)
# The pullback returns a structural gradient for s: a NamedTuple matching its fields
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

julia> gs = gradient(s) do s
           fun(s)
       end
((W = [1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0],),)
@darsnack One more question. Is there an alternative way I could have written the above explicit adjoint so that this example works fine?
I would go even further and say Zygote should remove the implicit parameter system, or at least Flux should. It seemed like a good idea, but 3 years later I think we've all learned it only causes pain. The main gain was syntactic sugar, but the system underlying it was never really that solid. This is just one of many unsolvable issues that arise from it; others are related to performance or compile time, along with other weird correctness edge cases. Instead, we should all explore different ways to give explicit parameters similarly nice syntax, which would be the best of all worlds.
We have the explicit form, and that would be good to use. Elsewhere, implicit gradients are tracked through the same rules as explicit ones. There's no design constraint that explains why one should work and the other not.
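A small sketch of the explicit form, reusing the custom adjoint from @ToucheSir's example: the gradient is taken with respect to the struct argument itself, so the structural gradient returned by the rule is used directly and nothing depends on implicit-parameter bookkeeping. The plain gradient-descent step and the learning rate η = 0.1 are illustrative assumptions, not part of the thread.

```julia
using Zygote
using Zygote: @adjoint

struct S
    W::Matrix{Float64}
end

fun(s::S) = sum(s.W)

# Same rule as in the thread: returns a structural gradient for the whole struct
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

s = S(randn(4, 4))
W0 = copy(s.W)

# Explicit form: differentiate with respect to the argument `s`
(ds,) = gradient(fun, s)   # ds is a NamedTuple mirroring the fields of S

# Illustrative gradient-descent step (η is an arbitrary choice)
η = 0.1
s.W .-= η .* ds.W
```

Because the gradient arrives as a value shaped like the model, nothing has to be looked up by identity afterwards, which sidesteps the missing s.W entry entirely.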
@ChrisRackauckas I would love nothing more, but unfortunately there hasn't been a big push behind figuring out how to bring explicit params to parity, let alone a migration plan. This includes non-syntactic issues such as how to handle tied weights and how to exclude certain params from optimization. I myself have at least a couple of pages of design notes on various aspects/challenges, and looking at what others are doing, it's clear this is not a trivial task! Anyhow, I just created a tracking project at https://github.com/orgs/FluxML/projects/2. Please add new issues/tasks as you encounter them; it would be great to record all this disparate discussion about implicit vs explicit params in one place.
I have encountered this issue several times. This is the smallest example I was able to find to reproduce it.

I noticed that gs is storing the correct gradients in W in another key, which equals Main.s, but I'm not even sure what that is and I cannot access it. But the correct key, s.W, is populated with nothing, which is wrong. What is going on here?
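The original reproduction snippet is not preserved in this copy of the thread. The following is a hedged reconstruction assembled from pieces elsewhere in the thread (the custom adjoint from @ToucheSir's example combined with the implicit-params call), not the author's exact code. Zygote.Params([s.W]) stands in for Flux's params(s). The last lines peek at Zygote internals (the grads field of Grads and a GlobalRef key, which is what prints as Main.s); these are implementation details that may differ between Zygote versions.

```julia
using Zygote
using Zygote: @adjoint

struct S
    W::Matrix{Float64}
end

fun(s::S) = sum(s.W)

# Custom rule: returns a structural gradient for the whole struct `s`
@adjoint function fun(s::S)
    fun(s), Δ -> ((; W = similar(s.W) .= Δ),)
end

s = S(randn(4, 4))          # non-constant global, referenced inside the closure below
ps = Zygote.Params([s.W])   # track s.W implicitly (stand-in for Flux's params(s))

gs = gradient(ps) do
    fun(s)
end

gs[s.W]                     # nothing, as reported in this issue

# Internal detail (may change between versions): the gradient for the global `s`
# is filed under a GlobalRef key, which prints as Main.s when run in Main
key = GlobalRef(@__MODULE__, :s)
gs.grads[key].W             # the structural gradient that never reached s.W
```

Roughly, the rule hands back a gradient for the global s as a whole, so Zygote records it under that global's key; the field access that would have credited s.W happens inside the rule, where the AD is not looking.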