
Improve some activation function gradients #392

Merged · 12 commits · Feb 24, 2022

Conversation


@mcabbott (Member) commented Feb 21, 2022:

This adds gradient rules for a few more activation functions.

It also upgrades the tests -- I think the two-argument forms were not being tested at all.
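
For a sense of what the new rules cover, something like the following is expected to hold (illustrative only, not taken from the PR; values approximate):

```julia
using NNlib, Zygote

# One-argument form: with a gradient rule defined, the element type is preserved.
Zygote.gradient(softplus, 1f0)                  # (0.7310586f0,)

# Two-argument form, e.g. leakyrelu(x, a): only the gradient with respect to x
# is defined; the slope a is treated as a constant.
Zygote.gradient(x -> leakyrelu(x, 0.1), -2.0)   # (0.1,)
```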

Now that I've pushed it to a GPU machine, I see that it does not fix issue #386: it marks the gradient of the 2nd argument of leakyrelu with ChainRulesCore.NotImplemented, and that isn't GPU-friendly:

julia> l, b = Flux.pullback(() -> gradient_penalty(cm₁, cx), params(cm₁)) # Fails to compile
ERROR: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceMatrix{ChainRulesCore.NotImplemented, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1115#1118"{typeof(*)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{ChainRulesCore.NotImplemented}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#1115#1118"{typeof(*)}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{ChainRulesCore.NotImplemented}}}, which is not isbits:
  .args is of type Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{Float32, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}, CUDA.CuRefValue{ChainRulesCore.NotImplemented}} which is not isbits.
    .2 is of type CUDA.CuRefValue{ChainRulesCore.NotImplemented} which is not isbits.
      .x is of type ChainRulesCore.NotImplemented which is not isbits.
        .mod is of type Module which is not isbits.
        .source is of type LineNumberNode which is not isbits.
          .file is of type Union{Nothing, Symbol} which is not isbits.
        .info is of type String which is not isbits.

We could revert to silently wrong answers, or make it NaN to avoid this (which is what DiffRules does).
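
For context, a small check along those lines (illustrative, not from the PR): the object produced by @not_implemented carries a Module, a LineNumberNode and a String, so it is not isbits, while a plain NaN is.

```julia
using ChainRulesCore

nt = ChainRulesCore.@not_implemented "2nd argument treated as a constant"
isbits(nt)      # false -- cannot be passed into a CUDA kernel
isbits(NaN32)   # true  -- broadcasts on the GPU without trouble
```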

Edit -- now made NaN. This runs, but the gradient does not.

Comment on lines 339 to 341
if Symbol(f) in NNlib.BINARY_ACTS
    @test rrule(f, rand(), rand()) !== nothing
    @test rrule(broadcasted, f, rand(2), rand(2)) !== nothing
mcabbott (Member Author):

I think these tests were never run, because NNlib.BINARY_ACTS contains tuples, not symbols.
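
For example (illustrative), with entries of the same shape as NNlib.BINARY_ACTS, the membership test can never succeed:

```julia
acts = [(:elu, :(deriv_elu(Ω, x2)), :(NoTangent()))]   # tuples, as in BINARY_ACTS

:elu in acts            # false -- a Symbol is never equal to a 3-tuple
:elu in first.(acts)    # true  -- presumably what the check intended
```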

Comment on lines 328 to 335
@testset "binary rule" begin
## Check that rules, including broadcast rules, are defined:
@test rrule(f, rand(), rand()) !== nothing
@test rrule(broadcasted, f, rand(2), rand(2)) !== nothing

## Correctness tests above don't check 2-arg version.
gradtest(x -> f(x, 0.2), 1 + rand(rng))
gradtest(x -> f(x, 0.7), 1 + rand(rng))
mcabbott (Member Author):

Old binary tests didn't attempt to check correctness, only that the rule existed. These ones do.

There is also a test of the gradient rule below.
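
A minimal version of the kind of correctness check this adds, using Zygote directly instead of the gradtest helper (illustrative sketch):

```julia
using NNlib, Zygote

# leakyrelu(x, a) has slope 1 for x > 0 and slope a otherwise.
for x in (-2.0, 3.0)
    g = Zygote.gradient(y -> leakyrelu(y, 0.2), x)[1]
    @assert g ≈ (x > 0 ? 1.0 : 0.2)
end
```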

Comment on lines +292 to +299
has_rule(a) = rrule(a, 1f0) === nothing ? "(no rule)" : ""

@testset "Gradient inference" begin
    @testset "$(a): $(has_rule(a))" for a in ACTIVATION_FUNCTIONS
        @testset "$T" for T in [Float16, Float32, Float64]
            for val in [-10, -1, 0, 1, 10]
                grad = @inferred gradient(a, T(val))
                @test typeof(grad[1]) == T
mcabbott (Member Author):

This replaces test_gradient_float_precision_preserving. I thought something was broken, so I made it print more information, but in the end nothing is broken; the tests are just easier to read now.
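
In miniature, the property being tested (illustrative; assumes the rule for the activation preserves precision):

```julia
using NNlib, Zygote

g16 = Zygote.gradient(sigmoid, Float16(1))[1]
g32 = Zygote.gradient(sigmoid, 1f0)[1]
(typeof(g16), typeof(g32))   # expected: (Float16, Float32)
```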

@testset "relu: " begin
# relu doesn't have to force floating point outputs
# The following ones can pass integers through. But it's not very important.
@testset "relu: Int -> Int" begin
mcabbott (Member Author):

I simplified the integer tests. There's no real need to test forever that Int32 is type-stable; they should work with Int, but I can't imagine them being useful in anything performance-sensitive.

Also, very long testset names cause things to wrap badly and become hard to read, which is what motivated taking the axe to these.
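
The point of the remaining integer test, in short (illustrative):

```julia
using NNlib

relu(3), relu(-2)   # (3, 0) -- relu can pass an Int straight through
softplus(1)         # 1.3132616875182228 -- most activations promote to Float64
```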

Comment on lines 843 to 846
NO_ACT_GRAD = ChainRulesCore.@not_implemented "for simplicitly NNlib assumes the 2nd argument of this activation function is a constant"

BINARY_ACTS = [ # f, df1, df2
    (:elu, :(deriv_elu(Ω, x2)), :(NoTangent())), # TODO use real deriv instead of DNE
]
    (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
mcabbott (Member Author):

@not_implemented is designed for exactly this, but it seems not to be GPU-friendly, sadly.
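
Roughly, an entry of that table turns into an rrule of the following shape (a hedged sketch with hypothetical stand-ins elu_ and deriv_elu_ for the NNlib internals, not the actual generated code):

```julia
using ChainRulesCore

# Hypothetical stand-ins so the sketch is self-contained:
elu_(x, a=1) = x > 0 ? float(x) : a * (exp(x) - 1)
deriv_elu_(Ω, a=1) = Ω > 0 ? one(Ω) : Ω + a

NO_ACT_GRAD = ChainRulesCore.@not_implemented "2nd argument treated as a constant"

# Shape of the rule generated from (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD):
function ChainRulesCore.rrule(::typeof(elu_), x, x2)
    Ω = elu_(x, x2)
    elu_pullback(dΩ) = (NoTangent(), dΩ * deriv_elu_(Ω, x2), NO_ACT_GRAD)
    return Ω, elu_pullback
end
```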

mcabbott (Member Author):

With NaN instead, the example from #386 runs.

mcabbott (Member Author):

No, in fact it runs forwards, but not backwards:

julia> l, b = Flux.pullback(() -> gradient_penalty(cm₁, cx), params(cm₁)) # Fails to compile
(0.00047222938f0, Zygote.var"#94#95"{Params, typeof(∂(#45)), Zygote.Context}(Params([Float32[-0.015067626;;], Float32[0.0]]), ∂(#45), Zygote.Context(IdDict{Any, Any}(Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(1, Any[Zygote.ZBack{ChainRules.var"#===_pullback#81"}(ChainRules.var"#===_pullback#81"())]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(1, Any[Zygote.ZBack{ChainRules.var"#===_pullback#81"}(ChainRules.var"#===_pullback#81"())]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing)), Zygote.Stack{Any}(0, Any[]) => Base.RefValue{Any}((idx = nothing, data = nothing))…))))

julia> b(1f0)
ERROR: Mutating arrays is not supported -- called copyto!(::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, _...)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#446#447"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})(#unused#::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:74
  [3] (::Zygote.var"#2351#back#448"{Zygote.var"#446#447"{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})(Δ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [4] Pullback
    @ ./broadcast.jl:871 [inlined]
  [5] Pullback
    @ ./broadcast.jl:868 [inlined]
  [6] Pullback
    @ ./broadcast.jl:864 [inlined]
  [7] Pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:275 [inlined]
  [8] (::typeof(∂(λ)))(Δ::Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ./compiler/interface2.jl:0
  [9] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [10] (::typeof(∂(λ)))(Δ::Tuple{Nothing, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ./compiler/interface2.jl:0
 [11] Pullback
    @ ./REPL[27]:2 [inlined]
 [12] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
 [13] Pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:357 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote ./compiler/interface2.jl:0
 [15] Pullback
    @ ./REPL[27]:3 [inlined]
 [16] (::typeof(∂(gradient_penalty)))(Δ::Float32)
    @ Zygote ./compiler/interface2.jl:0
 [17] Pullback
    @ ./REPL[36]:1 [inlined]
 [18] (::typeof(∂(#45)))(Δ::Float32)
    @ Zygote ./compiler/interface2.jl:0
 [19] (::Zygote.var"#94#95"{Params, typeof(∂(#45)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:357
 [20] top-level scope
    @ REPL[37]:1
 [21] top-level scope
    @ ~/.julia/packages/CUDA/KnJGx/src/initialization.jl:52

@ToucheSir (Member):

Am I right in assuming this is good to go? :)

@mcabbott (Member Author):

Yes, this is ready I think.

BINARY_ACTS = [ # f, dfdx1, dfdx2
    ## In the same order as above!
    (:leakyrelu, :(ifelse(Ω > 0, oftf(Ω, 1), oftf(Ω, x2))), NO_ACT_GRAD),
    (:elu, :(deriv_elu(Ω, x2)), NO_ACT_GRAD),
mcabbott (Member Author):

What rules like this don't solve is that you still need to change Dense(10 => 20, leakyrelu) into

Chain(Dense(10 => 20), x -> leakyrelu.(x, 0.1))

to provide the 2nd argument.

Not this PR, but it would be nice if you could write Dense(10 => 20, leakyrelu(0.1)) instead. Perhaps it has to be Dense(10 => 20, LeakyRelu(0.1)) with some new struct, but perhaps you could just be clever with overloading broadcasted, so that leakyrelu(0.1) makes a function while Dense(10 => 20, leakyrelu) doesn't create a vector of functions?
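
For illustration, one version of the struct idea (the LeakyRelu type below is hypothetical, not part of NNlib or Flux):

```julia
using NNlib: leakyrelu

# Hypothetical callable struct that closes over the slope, so it can be
# passed to Dense like any one-argument activation function.
struct LeakyRelu{T}
    alpha::T
end
(f::LeakyRelu)(x) = leakyrelu(x, f.alpha)

# Usage sketch: Dense(10 => 20, LeakyRelu(0.1)) would then broadcast the
# struct over the layer output, supplying the 2nd argument automatically.
```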
