From df7b0a3f29b7ce83f8b9d91b73f62416547ab3bd Mon Sep 17 00:00:00 2001
From: Mateusz Kaduk
Date: Thu, 16 Feb 2023 18:54:12 +0100
Subject: [PATCH] Add implementation of Lion optimiser

---
 src/Optimisers.jl |  2 +-
 src/rules.jl      | 30 ++++++++++++++++++++++++++++++
 test/rules.jl     |  2 +-
 3 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 8e8cb19f..2e32e755 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -14,7 +14,7 @@ export destructure
 include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
-       WeightDecay, ClipGrad, ClipNorm, OptimiserChain
+       WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion
 
 ###
 ### one-array functions
diff --git a/src/rules.jl b/src/rules.jl
index 0b366faa..57f2d752 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -217,6 +217,36 @@ function apply!(o::Adam, state, x, dx)
   return (mt, vt, βt .* β), dx′
 end
 
+"""
+    Lion(η = 0.001, β::Tuple = (0.9, 0.999))
+
+[Lion](https://arxiv.org/abs/2302.06675) optimiser.
+
+# Parameters
+- Learning rate (`η`): Magnitude by which gradients are updating the weights.
+- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
+  second (β2) momentum estimate.
+"""
+struct Lion{T} <: AbstractRule
+  eta::T
+  beta::Tuple{T,T}
+end
+Lion(η = 1f-3, β = (9f-1, 9.99f-1)) = Lion{typeof(η)}(η, β)
+
+init(o::Lion, x::AbstractArray) = zero(x)
+
+function apply!(o::Lion, state, x, dx)
+  η, β = o.eta, o.beta
+
+  @.. state = β[2] * state + (1-β[2]) * dx
+
+  # The paper writes the update in terms of the old momentum,
+  # but it is easy to solve for it in terms of the current momentum instead:
+  dx′ = @lazy η * sign((β[2]-β[1]) * dx + β[1] * state)
+
+  return state, dx′
+end
+
 """
     RAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
 
diff --git a/test/rules.jl b/test/rules.jl
index fd9660a1..0bac0145 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
   # All the rules at default settings:
   Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
   AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
-  AdamW(), RAdam(), OAdam(), AdaBelief(),
+  AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
   # A few chained combinations:
   OptimiserChain(WeightDecay(), Adam(0.001)),
   OptimiserChain(ClipNorm(), Adam(0.001)),
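
A quick way to see why the reformulated update in apply! above matches the paper's rule: with m_new = β2*m_old + (1-β2)*g, the quantity (β2-β1)*g + β1*m_new equals β2 times the paper's c = β1*m_old + (1-β1)*g, so its sign is identical. The sketch below is not part of the patch; the names g, m_old, c_paper, etc. are illustrative only, and it checks the identity numerically in plain Julia:

    # Verify that the sign argument used in apply! is a positive multiple
    # of the paper's update direction c = β1*m_old + (1-β1)*g.
    β1, β2 = 0.9f0, 0.999f0
    g, m_old = randn(Float32, 5), randn(Float32, 5)

    c_paper = β1 .* m_old .+ (1 - β1) .* g   # paper: update uses the old momentum
    m_new   = β2 .* m_old .+ (1 - β2) .* g   # momentum update, as in apply!
    c_impl  = (β2 - β1) .* g .+ β1 .* m_new  # argument of sign() in the patch

    @assert c_impl ≈ β2 .* c_paper           # same direction, since β2 > 0

Assuming the patch is applied, the new rule is used like any other in the package, e.g. opt_state = Optimisers.setup(Lion(), model) followed by Optimisers.update!(opt_state, model, grad).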