
Add implementation of Lion optimiser (#129)
mashu authored Feb 24, 2023
1 parent e2254b4 commit 9007ad5
Showing 3 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/Optimisers.jl
@@ -14,7 +14,7 @@ export destructure
include("rules.jl")
export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
WeightDecay, ClipGrad, ClipNorm, OptimiserChain
WeightDecay, ClipGrad, ClipNorm, OptimiserChain, Lion

###
### one-array functions
30 changes: 30 additions & 0 deletions src/rules.jl
@@ -217,6 +217,36 @@ function apply!(o::Adam, state, x, dx)
return (mt, vt, βt .* β), dx′
end

"""
Lion(η = 0.001, β::Tuple = (0.9, 0.999))
[Lion](https://arxiv.org/abs/2302.06675) optimiser.
# Parameters
- Learning rate (`η`): Magnitude of each update step; Lion changes every weight by `η * sign(...)`, so `η` directly sets the step size.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
second (β2) momentum estimate.
"""
struct Lion{T} <: AbstractRule
eta::T
beta::Tuple{T,T}
end
Lion(η = 1f-3, β = (9f-1, 9.99f-1)) = Lion{typeof(η)}(η, β)

init(o::Lion, x::AbstractArray) = zero(x)

function apply!(o::Lion, state, x, dx)
η, β = o.eta, o.beta

@.. state = β[2] * state + (1-β[2]) * dx

# The paper writes the update in terms of the old momentum,
# but it is easy to express it in terms of the current momentum instead:
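# With the new state = β[2]*old_state + (1-β[2])*dx, the paper's
# sign(β[1]*old_state + (1-β[1])*dx) equals sign((β[2]-β[1])*dx + β[1]*state)
# up to the positive factor 1/β[2], which the sign function ignores.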
dx′ = @lazy η * sign((β[2]-β[1]) * dx + β[1] * state)

return state, dx′
end
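For context, a minimal usage sketch (not part of this diff), assuming the standard Optimisers.jl `setup`/`update` API and a toy NamedTuple model:

using Optimisers

# Toy "model": any nested structure of arrays works with Optimisers.jl.
model = (W = randn(Float32, 3, 3), b = zeros(Float32, 3))

# Set up Lion with its default hyperparameters (η = 1f-3, β = (0.9, 0.999)).
state = Optimisers.setup(Lion(), model)

# A gradient with the same structure as the model (all ones, just for illustration).
grad = (W = ones(Float32, 3, 3), b = ones(Float32, 3))

# Each weight moves by η * sign(...), i.e. by at most η in magnitude.
state, model = Optimisers.update(state, model, grad)

Because the step is `η * sign(...)`, the update has constant magnitude per weight; the Lion paper suggests using a smaller learning rate than is typical for AdamW.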

"""
RAdam(η = 1f-3, β = (9f-1, 9.99f-1), ϵ = eps(typeof(η)))
2 changes: 1 addition & 1 deletion test/rules.jl
@@ -8,7 +8,7 @@ RULES = [
# All the rules at default settings:
Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(),
AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(),
AdamW(), RAdam(), OAdam(), AdaBelief(),
AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(),
# A few chained combinations:
OptimiserChain(WeightDecay(), Adam(0.001)),
OptimiserChain(ClipNorm(), Adam(0.001)),

