Latest release (0.7.9) broke softmax gradients with Tracker #251
Comments
I have found the root cause:
function softmax!(out::O, x::T; dims = 1) where {O<:AbstractArray,T<:AbstractArray}
    out .= exp.(x .- maximum(x; dims = dims))
    out ./= sum(out; dims = dims)
end
This works:
gradient(x -> dot([1.0, 0], softmax(x)), [0.0, 1])
This causes an error:
gradient(x -> dot(x, softmax(x)), [0.0, 1])
@CarloLucibello |
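For reference, the in-place definition quoted above can be compared with a version that mutates nothing. This is only a sketch: the name `softmax_outofplace` is hypothetical and not part of NNlib, and the check below only confirms that the two forms agree numerically. Whether Tracker can differentiate the second form is not verified here.

```julia
# In-place definition, copied from the comment above.
function softmax!(out::O, x::T; dims = 1) where {O<:AbstractArray,T<:AbstractArray}
    out .= exp.(x .- maximum(x; dims = dims))
    out ./= sum(out; dims = dims)
end

# Hypothetical out-of-place variant (assumption, not NNlib's code):
# every step allocates a new array instead of overwriting an existing one.
function softmax_outofplace(x::AbstractArray; dims = 1)
    e = exp.(x .- maximum(x; dims = dims))
    return e ./ sum(e; dims = dims)
end

x = [0.0, 1.0]
out = similar(x)
softmax!(out, x)                  # fills `out` in place
y = softmax_outofplace(x)         # returns a fresh array
println(out ≈ y)
```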
Zygote works fine; you can use mutating operations as long as you provide an explicit adjoint:
julia> using Zygote, NNlib, LinearAlgebra # Zygote v0.5.17, NNlib v0.7.9
julia> gradient(x -> dot([1.0, 0], softmax(x)), [0.0, 1])
([0.19661193324148185, -0.19661193324148185],)
julia> gradient(x -> dot(x, softmax(x)), [0.0, 1])
([0.07232948812851325, 0.9276705118714867],) |
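The two gradients reported above can be cross-checked without any AD package. The sketch below uses only Base Julia; `softmax_ref`, `softmax_vjp` and `fd_gradient` are hypothetical helper names introduced for this check, not NNlib or Zygote functions. It compares the closed-form softmax gradient against central finite differences for both test functions from the thread.

```julia
# Plain softmax for a vector (hypothetical helper, not NNlib's implementation).
softmax_ref(x) = (e = exp.(x .- maximum(x)); e ./ sum(e))

# Gradient of sum(w .* softmax(x)) with respect to x for a constant w.
# The softmax Jacobian diag(s) - s*s' is symmetric, so this equals s .* (w .- sum(s .* w)).
softmax_vjp(s, w) = s .* (w .- sum(s .* w))

# Central finite-difference gradient of a scalar function f at x.
function fd_gradient(f, x; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = [0.0, 1.0]
s = softmax_ref(x)

g1 = softmax_vjp(s, [1.0, 0.0])   # gradient of dot([1.0, 0], softmax(x))
g2 = s .+ softmax_vjp(s, x)       # gradient of dot(x, softmax(x)), by the product rule

fd1 = fd_gradient(z -> sum([1.0, 0.0] .* softmax_ref(z)), x)
fd2 = fd_gradient(z -> sum(z .* softmax_ref(z)), x)

println(isapprox(g1, fd1; atol = 1e-6), " ", isapprox(g2, fd2; atol = 1e-6))
```

Under these assumptions, `g1` and `g2` should agree with the Zygote output quoted above up to floating-point rounding, which supports the view that the reported values are the correct gradients.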
The problem should be only with Tracker, which must be performing tracing without exploiting the definition of ∇softmax. I'm not very familiar with Tracker, but I think the problem should be fixed there, not here |
maybe this will already be fixed by FluxML/Tracker.jl#90 |
gradient(x -> dot(x, softmax(x)), [0.0, 1]) works with Zygote@0.5.17, but Zygote@0.6.0 raises that error. |
That is fine: we moved NNlib's AD rules out of Zygote in 0.6.0. They are being ported here and will be available in NNlib v0.8 |
So current NNlib is broken and shouldn't be used since it's incompatible with zygote 0.5? |
@DhairyaLGandhi ??? current NNlib is perfectly fine and compatible with zygote 0.5 |
Sorry it was a question, I mistyped. I edited the question mark in |
The following example works flawlessly with NNlib 0.7.8 but yields an error with NNlib 0.7.9:
Due to this problem, the CI tests in DistributionsAD (e.g. https://github.com/TuringLang/DistributionsAD.jl/pull/143/checks?check_run_id=1592162436) and Bijectors currently fail.