Give Apollo its own eta for adjust, and use sqrt(#params) for GradNormGrowthLimiter
murrellb committed Dec 14, 2024
1 parent 6aa32c1 commit d9637c6
Showing 1 changed file with 19 additions and 15 deletions: src/rules.jl
@@ -603,10 +603,10 @@ end
"""
GradNormGrowthLimiter(γ = 1.1; m = 1e-3, ϵ = 1e-8, throw = true, paramscale_min = true)
Gradient norm growth limiter from Chen et al. (https://arxiv.org/pdf/2410.01623) and used with Apollo in Zhu et al. (https://arxiv.org/pdf/2412.05270).
With Optimisers.jl this will apply per-tensor, which may not be the same as the implementations in these papers. It still seems to help, but the ideal settings may vary.
This also introduces `m`, a hard minimum on the gradient norm, and never rescales grads below this, preventing a tensor from getting "trapped" near zero.
This can be a fixed min, or scaled by the number of parameters in the tensor (with `paramscale_min = true`).
Gradient norm growth limiter. Inspired by [Chen et al.](https://arxiv.org/abs/2410.01623) and used with Apollo in [Zhu et al.](https://arxiv.org/abs/2412.05270), but
with Optimisers.jl this will apply per-tensor instead of per-model, and as a result the defaults are different. `γ` controls the maximum factor by which the gradient norm can grow
from one step to the next. This implementation also introduces `m`, a hard minimum on the gradient norm threshold, and never rescales grads below this, preventing a tensor
from getting "trapped" near zero. This can be a fixed minimum, or scaled by the square root of the number of parameters in the tensor (with `paramscale_min = true`).
"""
struct GradNormGrowthLimiter <: AbstractRule
γ::Float64
@@ -630,7 +630,7 @@ function apply!(o::GradNormGrowthLimiter, state, x::AbstractArray{T}, dx) where
else
#If you're below the hard min, then don't scale
if o.paramscale_min
minthresh = o.m * length(dx)
minthresh = o.m * sqrt(length(dx))
else
minthresh = o.m

end
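For reference, a rough standalone sketch of the limiting rule this hunk implements: cap a tensor's gradient norm at γ times its previous value, but never rescale once the norm is already below a hard floor, which with `paramscale_min` is scaled by the square root of the number of parameters. The function name `limit_growth` and the explicit state passing are illustrative only, not the package's `apply!`.

#Illustrative sketch (not the package's apply!): per-tensor norm growth limiting
#with a hard floor scaled by sqrt of the parameter count.
function limit_growth(dx::AbstractArray, prevnorm; γ = 1.1, m = 1e-3, ϵ = 1e-8, paramscale_min = true)
    n = sqrt(sum(abs2, dx))                                 #current gradient norm
    minthresh = paramscale_min ? m * sqrt(length(dx)) : m   #below this, never rescale
    if n > minthresh && prevnorm > 0 && n > γ * prevnorm
        dx = dx .* (γ * prevnorm / (n + ϵ))                 #limit growth to a factor of γ per step
    end
    return dx, n                                            #the new norm becomes next step's state
end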
@@ -659,19 +659,20 @@ Apollo optimizer from Zhu et al. (https://arxiv.org/pdf/2412.05270). Tracks mome
First argument can be an AdamW optimizer, or a learning rate (which will use the default AdamW optimizer with that learning rate). Second argument can be a rank, or a function
to compute the rank from the second dimension (or the product of all dims > 1) of the weight matrix (or tensor).
"""
struct Apollo{T1} <: AbstractRule
struct Apollo{T1, T2, T3, T4, T5} <: AbstractRule
opt::T1
r::Function #Maps non-first dims to rank
u::Int #Subspace update frequency (T in paper)
sort_dims::Bool #Whether to swap the dims of x and dx when the second dim is smaller than the first
eta::T2
r::T3 #Maps non-first dims to rank
u::T4 #Subspace update frequency (T in paper)
sort_dims::T5 #Whether to swap the dims of x and dx when the second dim is smaller than the first
end


Apollo() = Apollo(AdamW(0.001), dim -> ceil(Int, sqrt(dim)), 100, true)
Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), rank_function, u, sort_dims)
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), dim -> max(dim, rank), u, sort_dims)
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, rank_function, u, sort_dims)
Apollo() = Apollo(AdamW(0.001), 0.001, dim -> ceil(Int, sqrt(dim)), 100, true)
Apollo(η::Real, rank::Int; u = 100, sort_dims = true) = Apollo(AdamW(η), η, dim -> max(dim, rank), u, sort_dims)
Apollo(η::Real; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(AdamW(η), η, rank_function, u, sort_dims)
Apollo(opt::AdamW, rank::Int; u = 100, sort_dims = true) = Apollo(opt, opt.eta, dim -> max(dim, rank), u, sort_dims)
Apollo(opt::AdamW; rank_function::Function = dim -> ceil(Int, sqrt(dim)), u = 100, sort_dims = true) = Apollo(opt, opt.eta, rank_function, u, sort_dims)
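To make the new constructor set concrete, a hedged usage sketch (assuming this rules.jl is loaded as part of an Optimisers.jl build, so setup works as usual; the model below is just a placeholder NamedTuple):

using Optimisers

model = (W = randn(Float32, 256, 64), b = zeros(Float32, 256))

rule_default = Apollo()                     #AdamW(0.001) inside, eta = 0.001, default rank function
rule_fixed   = Apollo(1f-3, 32)             #learning rate plus an integer rank argument
rule_custom  = Apollo(AdamW(1f-3); rank_function = dim -> ceil(Int, sqrt(dim)), u = 100)

state = Optimisers.setup(rule_fixed, model) #per-tensor state, as with any Optimisers.jl rule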


#Use the base init and apply for 1D arrays
init(o::Apollo, x::AbstractArray{T,1}) where T = init(o.opt, x)
@@ -706,7 +706,7 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
swapped = true

end
(mt, vt, βt), t, P = state
η = T(o.opt.eta)
η = T(o.eta) #This is what will get modified by adjust
λ = T(o.opt.lambda)
β = T.(o.opt.beta)
ϵ = T(o.opt.epsilon)
@@ -728,6 +729,9 @@ function apply!(o::Apollo, state, x::AbstractArray{T}, dx) where T
return ((mt, vt, βt .* β), t+1, P), reshape(dx′′, original_size)
end

#Notes: chuck the AdamW from the struct, so that adjust will just work.
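For context on the note above, a hedged sketch of why storing eta on Apollo itself matters: Optimisers.jl's adjust rebuilds each rule with its eta field replaced, so the value apply! now reads (o.eta) is the one adjust changes, independent of the wrapped AdamW. The model and numbers below are illustrative only.

using Optimisers

model = (W = randn(Float32, 128, 32),)
state = Optimisers.setup(Apollo(1f-3), model)

#adjust returns a new state tree whose rules have their eta field replaced; since
#Apollo now carries eta directly, this is the learning rate apply! will use.
state = Optimisers.adjust(state, 3f-4)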



"""
WeightDecay(λ = 5e-4)
