Improve some activation function gradients #392
Conversation
test/activations.jl
Outdated
if Symbol(f) in NNlib.BINARY_ACTS
    @test rrule(f, rand(), rand()) !== nothing
    @test rrule(broadcasted, f, rand(2), rand(2)) !== nothing
I think these tests were never run, because NNlib.BINARY_ACTS contains tuples, not symbols.
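A minimal sketch of the fix, assuming the tuples keep the function name as their first element (as in the BINARY_ACTS table quoted further down):

# Compare against the extracted names, not the (name, df1, df2) tuples themselves.
if Symbol(f) in first.(NNlib.BINARY_ACTS)
    @test rrule(f, rand(), rand()) !== nothing
    @test rrule(broadcasted, f, rand(2), rand(2)) !== nothing
end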
test/activations.jl
Outdated
@testset "binary rule" begin
    ## Check that rules, including broadcast rules, are defined:
    @test rrule(f, rand(), rand()) !== nothing
    @test rrule(broadcasted, f, rand(2), rand(2)) !== nothing

    ## Correctness tests above don't check 2-arg version.
    gradtest(x -> f(x, 0.2), 1 + rand(rng))
    gradtest(x -> f(x, 0.7), 1 + rand(rng))
The old binary tests didn't attempt to check correctness, only that the rule existed. These ones do.
There is also a test of the gradient rule below.
has_rule(a) = rrule(a, 1f0) === nothing ? "(no rule)" : ""

@testset "Gradient inference" begin
    @testset "$(a): $(has_rule(a))" for a in ACTIVATION_FUNCTIONS
        @testset "$T" for T in [Float16, Float32, Float64]
            for val in [-10, -1, 0, 1, 10]
                grad = @inferred gradient(a, T(val))
                @test typeof(grad[1]) == T
This replaces test_gradient_float_precision_preserving. I thought something was broken, so I made it print more stuff, but in the end it's not broken; the tests are just easier to read now.
@testset "relu: " begin
# relu doesn't have to force floating point outputs
# The following ones can pass integers through. But it's not very important.
@testset "relu: Int -> Int" begin
I simplified the integer tests. There's no real need to keep testing that Int32 is type-stable; they should work with Int, but I can't imagine them being useful in anything performance-sensitive.
Also, very long testset names make the output wrap badly and become hard to read, which is what motivated taking the axe to these.
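For reference, the kind of Int test that was kept is roughly this (a minimal sketch, not the exact file contents):

@testset "relu: Int -> Int" begin
    @test relu(3) === 3    # positive integers pass through unchanged
    @test relu(-3) === 0   # negative integers clamp to zero, still an Int
end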
src/activations.jl
Outdated
NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicity NNlib assumes the 2nd argument of this activation function is a constant"

BINARY_ACTS = [ # f, df1, df2
    (:elu, :(deriv_elu(Ω, x2)), :(NoTangent())), # TODO use real deriv instead of DNE
]
(:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
@not_implemented is designed for exactly this, but it seems not to be GPU-friendly, sadly.
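For illustration only (not the NNlib source, which generates its rules from the BINARY_ACTS table), the pattern looks roughly like this, using a hypothetical stand-alone function so nothing in NNlib is overwritten:

using ChainRulesCore

# Toy two-argument activation; the rrule marks the slope argument's
# gradient as not implemented instead of silently returning a wrong value.
myleakyrelu(x, a) = ifelse(x > 0, float(x), float(a * x))

function ChainRulesCore.rrule(::typeof(myleakyrelu), x::Real, a::Real)
    Ω = myleakyrelu(x, a)
    function myleakyrelu_pullback(ΔΩ)
        ∂x = ΔΩ * ifelse(x > 0, one(Ω), oftype(Ω, a))
        ∂a = @not_implemented "assumes the 2nd argument of this activation is a constant"
        return (NoTangent(), ∂x, ∂a)
    end
    return Ω, myleakyrelu_pullback
end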
With NaN instead, the example from #386 runs.
No, in fact it runs forwards, but not backwards:
julia> l, b = Flux.pullback(() -> gradient_penalty(cm₁, cx), params(cm₁)) # Fails to compile
(0.00047222938f0, Zygote.var"#94#95"{Params, typeof(∂(#45)), Zygote.Context}(Params([Float32[-0.015067626;;], Float32[0.0]]), ∂(#45), Zygote.Context(IdDict{Any, Any}(Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(1, Any[Zygote.ZBack{ChainRules.var"#===_pullback#81"}(ChainRules.var"#===_pullback#81"())]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(1, Any[Zygote.ZBack{ChainRules.var"#===_pullback#81"}(ChainRules.var"#===_pullback#81"())]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing))…))))
julia> b(1f0)
ERROR: Mutating arrays is not supported -- called copyto!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, _...)
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] (::Zygote.var"#446#447"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(#unused#::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:74
[3] (::Zygote.var"#2351#back#448"{Zygote.var"#446#447"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] Pullback
@ ./broadcast.jl:871 [inlined]
[5] Pullback
@ ./broadcast.jl:868 [inlined]
[6] Pullback
@ ./broadcast.jl:864 [inlined]
[7] Pullback
@ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:275 [inlined]
[8] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
@ Zygote ./compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[10] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
@ Zygote ./compiler/interface2.jl:0
[11] Pullback
@ ./REPL[27]:2 [inlined]
[12] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ./compiler/interface2.jl:0
[13] Pullback
@ ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:357 [inlined]
[14] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote ./compiler/interface2.jl:0
[15] Pullback
@ ./REPL[27]:3 [inlined]
[16] (::typeof(∂(gradient_penalty)))(Δ::Float32)
@ Zygote ./compiler/interface2.jl:0
[17] Pullback
@ ./REPL[36]:1 [inlined]
[18] (::typeof(∂(#45)))(Δ::Float32)
@ Zygote ./compiler/interface2.jl:0
[19] (::Zygote.var"#94#95"{Params, typeof(∂(#45)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:357
[20] top-level scope
@ REPL[37]:1
[21] top-level scope
@ ~/.julia/packages/CUDA/KnJGx/src/initialization.jl:52
Am I right in assuming this is good to go? :)
Yes, this is ready I think.
BINARY_ACTS = [ # f, dfdx1, dfdx2
    ## In the same order as above!
    (:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
    (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
What rules like this don't solve is that, to provide the 2nd argument, you need to change Dense(10 => 20, leakyrelu) into Chain(Dense(10 => 20), x -> leakyrelu.(x, 0.1)).
Not this PR, but it would be nice if you could write Dense(10 => 20, leakyrelu(0.1)) instead. Perhaps it has to be Dense(10 => 20, LeakyRelu(0.1)) with some new struct (see the sketch below), or perhaps you can just be clever with overloading broadcasted, so that leakyrelu(0.1) makes a function but Dense(10 => 20, leakyrelu) doesn't create a vector of functions?
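A minimal sketch of the struct idea; LeakyRelu here is hypothetical, it exists in neither NNlib nor Flux:

using NNlib: leakyrelu

# Callable wrapper fixing the slope, so it can be passed wherever
# Flux expects a unary activation function.
struct LeakyRelu{T}
    a::T
end
(f::LeakyRelu)(x) = leakyrelu(x, f.a)

# Usage sketch: Dense(10 => 20, LeakyRelu(0.1)) would then broadcast
# the wrapped function over the layer's output.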
This adds gradient rules for a few more activation functions.
It also upgrades the tests -- I think the two-argument forms were not being tested at all.
Now that I've pushed it to a GPU machine, I see that it does not fix issue #386, because it marks the gradient of the 2nd arg of leakyrelu with ChainRulesCore.NotImplemented, and it seems that isn't GPU-friendly. We could revert to silently wrong answers, or make the gradient NaN to avoid this (which is what DiffRules does).
Edit -- now made NaN. The example now runs forwards, but the gradient still does not.
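For reference, a sketch of what the NaN version of a table entry might look like, following the (f, dfdx1, dfdx2) format quoted above (the exact expression in the merged code may differ):

# 2nd-argument gradient as NaN rather than @not_implemented,
# so the broadcast rule stays GPU-friendly:
(:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), :(oftf(Ω, NaN))),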