Merge #1133
1133: add ClipValue and ClipNorm r=CarloLucibello a=AStupidBear



Co-authored-by: Yao Lu <luyaocns@gmail.com>
bors[bot] and AStupidBear authored May 15, 2020
2 parents fab53e0 + 0075868 commit b6a5dd7
Showing 6 changed files with 62 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -11,6 +11,7 @@ CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
13 changes: 13 additions & 0 deletions docs/src/training/optimisers.md
@@ -140,3 +140,16 @@ ExpDecay
InvDecay
WeightDecay
```

## Gradient Clipping

Gradient clipping is useful for training recurrent neural networks, which have a tendency to suffer from the exploding gradient problem. An example usage is

```julia
opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))
```

```@docs
ClipValue
ClipNorm
```
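
As a reading aid (not part of the commit), here is a minimal sketch of how the composed optimiser from the docs example above could be used end to end with `Flux.train!`. The model, data, loss, and learning rates are illustrative assumptions.

```julia
using Flux
using Flux.Optimise: Optimiser   # Optimiser itself is not re-exported from Flux

# Toy regression setup (illustrative only).
model = Chain(Dense(10, 32, relu), Dense(32, 1))
x, y = randn(Float32, 10, 100), randn(Float32, 1, 100)
loss(x, y) = Flux.mse(model(x), y)

# Clamp every gradient entry to [-1e-3, 1e-3] before ADAM applies its update.
opt = Optimiser(ClipValue(1e-3), ADAM(1e-3))

Flux.train!(loss, params(model), [(x, y)], opt)
```

Swapping `ClipValue(1e-3)` for `ClipNorm(1.0)` would instead rescale each gradient array so its L2 norm stays at or below 1.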
6 changes: 4 additions & 2 deletions src/Flux.jl
@@ -3,7 +3,8 @@ module Flux
# Zero Flux Given

using Base: tail
-using Zygote, MacroTools, Juno, Reexport, Statistics, Random
+using Statistics, Random, LinearAlgebra
+using Zygote, MacroTools, Juno, Reexport
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -20,7 +21,8 @@ using .Optimise
using .Optimise: @epochs
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
-ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
+ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
+ClipValue, ClipNorm


using CuArrays
5 changes: 4 additions & 1 deletion src/optimise/Optimise.jl
@@ -1,9 +1,12 @@
module Optimise

using LinearAlgebra

export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
-InvDecay, ExpDecay, WeightDecay, stop, Optimiser
+InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
+ClipValue, ClipNorm

include("optimisers.jl")
include("train.jl")
28 changes: 28 additions & 0 deletions src/optimise/optimisers.jl
@@ -533,3 +533,31 @@ function apply!(o::WeightDecay, x, Δ)
  wd = o.wd
  @. Δ += wd * x
end

"""
ClipValue(thresh)
Clip gradients when their absolute value exceeds `thresh`.
"""
mutable struct ClipValue{T}
thresh::T
end

apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)

"""
ClipNorm(thresh)
Clip gradients when their L2 norm exceeds `thresh`.
"""
mutable struct ClipNorm{T}
thresh::T
end

function apply!(o::ClipNorm, x, Δ)
Δnrm = norm(Δ)
if Δnrm > o.thresh
rmul!(Δ, o.thresh / Δnrm)
end
return Δ
end
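
To make the behaviour of the two rules concrete, here is a quick sketch (not part of the commit) of calling `apply!` directly on a mock gradient. The values are arbitrary, and `nothing` is passed for the parameter argument only because these two rules never read it.

```julia
using LinearAlgebra
using Flux.Optimise: ClipValue, ClipNorm, apply!

Δ = [3.0 -0.2; 0.5 -4.0]                  # mock gradient, with norm(Δ) ≈ 5.03

# ClipValue: clamp each entry into [-thresh, thresh].
apply!(ClipValue(1.0), nothing, copy(Δ))  # -> [1.0 -0.2; 0.5 -1.0]

# ClipNorm: rescale the whole array so its L2 norm is at most thresh.
Δn = apply!(ClipNorm(1.0), nothing, copy(Δ))
norm(Δn)                                  # ≈ 1.0
```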
12 changes: 12 additions & 0 deletions test/optimise.jl
@@ -89,3 +89,15 @@ end
@test decay_steps == ground_truth
@test o.eta == o.clip
end

@testset "Clipping" begin
w = randn(10, 10)
loss(x) = sum(w * x)
θ = Params([w])
x = 1000 * randn(10)
= gradient(() -> loss(x), θ)[w]
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
@test all(w̄_value .<= 1)
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
@test norm(w̄_norm) <= 1
end
