From 0ece9f193bd8d78cfac1991cd7a0bb9b65b8b75f Mon Sep 17 00:00:00 2001
From: Samuel Albert
Date: Thu, 16 May 2019 11:45:19 -0400
Subject: [PATCH] modifications to support proximal optimization

---
 src/Flux.jl                    |  5 +++--
 src/optimise/Optimise.jl       |  4 +++-
 src/optimise/regularization.jl | 34 ++++++++++++++++++++++++++++++++++
 src/optimise/train.jl          |  5 ++++-
 4 files changed, 44 insertions(+), 4 deletions(-)
 create mode 100644 src/optimise/regularization.jl

diff --git a/src/Flux.jl b/src/Flux.jl
index a041a69a8d..7a4c9f8fff 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
-       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib
@@ -21,7 +21,8 @@ using .Optimise
 using .Optimise: @epochs
 export SGD, Descent, ADAM, Momentum, Nesterov, RMSProp, ADAGrad, AdaMax,
        ADADelta, AMSGrad, NADAM,
-       ADAMW, InvDecay, ExpDecay, WeightDecay
+       ADAMW, InvDecay, ExpDecay, WeightDecay,
+       L1_regularization, L2_regularization
 
 include("utils.jl")
 include("onehot.jl")
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 5bb38d1ecb..58955fa992 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -3,9 +3,11 @@ module Optimise
 export train!, SGD, Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,
-  InvDecay, ExpDecay, WeightDecay, stop, Optimiser
+  InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
+  L1_regularization, L2_regularization
 
 include("optimisers.jl")
+include("regularization.jl")
 include("train.jl")
 include("deprecations.jl")
 
diff --git a/src/optimise/regularization.jl b/src/optimise/regularization.jl
new file mode 100644
index 0000000000..6900eeaa4a
--- /dev/null
+++ b/src/optimise/regularization.jl
@@ -0,0 +1,34 @@
+# Proximal updates for convex regularization
+using LinearAlgebra
+
+struct L1_regularization
+  α::Float64
+  f::Function
+end
+
+shrink(α) = f(z) = z > α ? α : z < -α ? -α : z
+
+function L1_regularization(α)
+  return L1_regularization(α, shrink(α))
+end
+
+function apply!(r::L1_regularization, x, Δ)
+  z = data(x)
+  Δ .= r.f.(z)
+  return Δ
+end
+
+struct L2_regularization
+  α::Float64
+end
+
+function apply!(r::L2_regularization, x, Δ)
+  z = data(x)
+  norm_z = norm(z)
+  if norm_z > r.α
+    Δ .= (r.α/norm_z) .* z
+  else
+    Δ .= z
+  end
+  return Δ
+end
diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index ab8be57898..63cd6ddc29 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -63,7 +63,7 @@ The callback can call `Flux.stop()` to interrupt the training loop.
 
 Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
 """
-function train!(loss, ps, data, opt; cb = () -> ())
+function train!(loss, ps, data, opt; regularization = nothing, cb = () -> ())
   ps = Params(ps)
   cb = runall(cb)
   @progress for d in data
@@ -72,6 +72,9 @@
       gs = gradient(ps) do
         loss(d...)
       end
       update!(opt, ps, gs)
+      if regularization != nothing
+        update!(regularization, ps, gs)
+      end
       if cb() == :stop
         depwarn("Use of `:stop` is deprecated; use `Flux.stop()` instead", :stop)
         break
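
[Illustrative note, not part of the patch] With these changes, `apply!(r, x, Δ)` overwrites the update `Δ` with `x - prox(x)` computed from the current parameter value (the incoming gradient is ignored), so the extra `update!(regularization, ps, gs)` call in `train!` performs the proximal step `x ← prox(x)` right after the ordinary optimiser step. The sketch below shows how the new `regularization` keyword might be used; the model, toy data, and the value `α = 0.01` are made up for illustration and assume the Flux 0.8-era API this patch targets.

```julia
using Flux

# Hypothetical example (not from the patch): a small linear model and one toy batch.
m = Dense(10, 1)
loss(x, y) = Flux.mse(m(x), y)
data = [(rand(10, 32), rand(1, 32))]

opt = Descent(0.1)              # ordinary gradient step
reg = L1_regularization(0.01)   # proximal soft-threshold with α = 0.01
# reg = L2_regularization(0.01) # alternative: proximal step on the ℓ2 norm

# `regularization` is the keyword argument added to `train!` by this patch.
Flux.train!(loss, params(m), data, opt; regularization = reg)
```

Because the proximal step is applied per parameter array after the gradient update, `L1_regularization(α)` moves every entry toward zero by at most `α` and sets entries with magnitude below `α` exactly to zero (soft-thresholding), while `L2_regularization(α)` shrinks the whole array by `α` in norm, zeroing it outright once its norm falls below `α`.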