Merge pull request #294 from kozistr/update/codes
[Feature] Cautious optimizer, improved stability for the ADOPT optimizer, and a new `random` projection type for the `GaLore` optimizer
kozistr authored Nov 27, 2024
2 parents 131314b + db82a58 commit 1524b89
Showing 10 changed files with 85 additions and 16 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |

## Supported LR Scheduler

6 changes: 6 additions & 0 deletions docs/changelogs/v3.3.0.md
@@ -8,6 +8,12 @@
* [Modified Adam Can Converge with Any β2 with the Optimal Rate](https://arxiv.org/abs/2411.02853)
* Implement `FTRL` optimizer. (#291)
* [Follow The Regularized Leader](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf)
* Implement `Cautious optimizer` feature. (#294)
* [Improving Training with One Line of Code](https://arxiv.org/pdf/2411.16085v1)
* You can use it by setting `cautious=True` for the `Lion`, `AdaFactor`, and `AdEMAMix` optimizers; see the sketch after this list.
* Improve the stability of `ADOPT` optimizer. (#294)
* [Note](https://github.com/iShohei220/adopt?tab=readme-ov-file#update-on-nov-22-2024)
* Support a new projection type `random` for `GaLoreProjector`. (#294)
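
A rough usage sketch of the new flag (the toy model is a placeholder, and `Lion` is assumed to be importable from the package top level, as in the tests; `AdaFactor` and `AdEMAMix` accept the same `cautious=True` argument):

```python
import torch

from pytorch_optimizer import Lion

model = torch.nn.Linear(16, 2)  # placeholder model for illustration
optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-3, cautious=True)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()       # the cautious mask is applied inside step()
optimizer.zero_grad()
```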

### Refactor

1 change: 1 addition & 0 deletions docs/index.md
@@ -183,6 +183,7 @@ get_supported_optimizers(['adam*', 'ranger*'])
| SOAP | *Improving and Stabilizing Shampoo using Adam* | [github](https://github.com/nikhilvyas/SOAP) | <https://arxiv.org/abs/2409.11321> | [cite](https://ui.adsabs.harvard.edu/abs/2024arXiv240911321V/exportcitation) |
| ADOPT | *Modified Adam Can Converge with Any β2 with the Optimal Rate* | [github](https://github.com/iShohei220/adopt) | <https://arxiv.org/abs/2411.02853> | [cite](https://github.com/iShohei220/adopt?tab=readme-ov-file#citation) |
| FTRL | *Follow The Regularized Leader* | | <https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf> | |
| Cautious | *Improving Training with One Line of Code* | [github](https://github.com/kyleliang919/C-Optim) | <https://arxiv.org/pdf/2411.16085v1> | [cite](https://github.com/kyleliang919/C-Optim?tab=readme-ov-file#citation) |

## Supported LR Scheduler

11 changes: 11 additions & 0 deletions pytorch_optimizer/base/optimizer.py
@@ -255,6 +255,17 @@ def approximate_sq_grad(
c_factor: torch.Tensor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
torch.mul(r_factor, c_factor, out=output)

@staticmethod
def apply_cautious(update: torch.Tensor, grad: torch.Tensor) -> None:
r"""Apply the Cautious Optimizer feature.
:param update: torch.Tensor. update. it will be masked in an in-place manner.
:param grad: torch.Tensor. gradient.
"""
mask = (update * grad > 0).to(grad.dtype)
mask.mul_(mask.numel() / (mask.sum() + 1))
update.mul_(mask)
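
For intuition, here is a standalone sketch of the same masking rule on toy tensors (not part of the diff): coordinates where the update disagrees in sign with the gradient are zeroed, and the survivors are rescaled by `numel / (count + 1)` so the overall update magnitude is roughly preserved.

```python
import torch

def cautious_mask(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # keep only the coordinates where update and gradient agree in sign
    mask = (update * grad > 0).to(grad.dtype)
    # rescale the surviving coordinates to compensate for the zeroed ones
    mask.mul_(mask.numel() / (mask.sum() + 1))
    return update * mask

update = torch.tensor([0.5, -0.2, 0.1, -0.4])
grad = torch.tensor([1.0, 1.0, -1.0, -1.0])
print(cautious_mask(update, grad))  # only the 1st and 4th coordinates survive, scaled by 4 / 3
```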

@staticmethod
def validate_range(x: float, name: str, low: float, high: float, range_type: str = '[)') -> None:
if range_type == '[)' and not low <= x < high:
7 changes: 6 additions & 1 deletion pytorch_optimizer/optimizer/adafactor.py
@@ -30,6 +30,7 @@ class AdaFactor(BaseOptimizer):
:param momentum_dtype: torch.dtype. type of momentum variable. The ViT paper observed that storing momentum in
half-precision (bfloat16) does not affect training dynamics or the final outcome, while reducing optimizer
overhead from 2-fold to 1.5-fold.
:param cautious: bool. whether to use the Cautious variant.
"""

def __init__(
@@ -49,6 +50,7 @@ def __init__(
eps1: float = 1e-30,
eps2: float = 1e-3,
momentum_dtype: torch.dtype = torch.bfloat16,
cautious: bool = False,
**kwargs,
):
self.validate_learning_rate(lr)
@@ -62,6 +64,7 @@
self.eps1 = eps1
self.eps2 = eps2
self.momentum_dtype = momentum_dtype
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -214,7 +217,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1.0 - beta1)

update = exp_avg
update = exp_avg.clone()
if self.cautious:
self.apply_cautious(update, grad)

self.apply_weight_decay(
p=p,
18 changes: 12 additions & 6 deletions pytorch_optimizer/optimizer/ademamix.py
@@ -19,6 +19,7 @@ class AdEMAMix(BaseOptimizer):
:param fixed_decay: bool. fix weight decay.
:param alpha: float. usually between 4 and 10 would work well.
:param t_alpha_beta3: Optional[float]. number of steps over which `alpha` and `beta3` are warmed up; the total number of training iterations is the preferred value when it is set.
:param cautious: bool. whether to use the Cautious variant.
:param eps: float. term added to the denominator to improve numerical stability.
"""

@@ -32,6 +33,7 @@ def __init__(
fixed_decay: bool = False,
alpha: float = 5.0,
t_alpha_beta3: Optional[float] = None,
cautious: bool = False,
eps: float = 1e-8,
**kwargs,
):
@@ -42,6 +44,8 @@
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
@@ -71,9 +75,7 @@ def reset(self):

@staticmethod
def schedule_alpha(t_alpha_beta3: Optional[float], step: int, alpha: float) -> float:
if t_alpha_beta3 is None:
return alpha
return min(step * alpha / t_alpha_beta3, alpha)
return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

@staticmethod
def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta3: float) -> float:
@@ -107,6 +109,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
bias_correction1: float = self.debias(beta1, group['step'])
bias_correction2_sq: float = math.sqrt(self.debias(beta2, group['step']))

step_size: float = group['lr'] / bias_correction1

alpha_t: float = self.schedule_alpha(group['t_alpha_beta3'], group['step'], group['alpha'])
beta3_t: float = self.schedule_beta3(group['t_alpha_beta3'], group['step'], beta1, beta3)

@@ -140,10 +144,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
exp_avg_slow.mul_(beta3_t).add_(grad, alpha=1.0 - beta3_t)

de_nom = (exp_avg_sq.sqrt() / bias_correction2_sq).add_(group['eps'])
de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])

step_size = group['lr'] / bias_correction1
update = (exp_avg + alpha_t * exp_avg_slow).div_(de_nom)
if self.cautious:
self.apply_cautious(update, grad)

p.addcdiv_(exp_avg + alpha_t * exp_avg_slow, de_nom, value=-step_size)
p.add_(update, alpha=-step_size)

return loss
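
A side note on the `schedule_alpha` refactor above: the one-liner is a linear warmup of `alpha` over `t_alpha_beta3` steps, capped at `alpha`. A quick standalone check with illustrative values:

```python
def schedule_alpha(t_alpha_beta3, step, alpha):
    return alpha if t_alpha_beta3 is None else min(step * alpha / t_alpha_beta3, alpha)

for step in (1, 100, 500, 1000, 2000):
    print(step, schedule_alpha(1000, step, 5.0))  # 0.005, 0.5, 2.5, 5.0, 5.0 (capped at alpha)
```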
17 changes: 13 additions & 4 deletions pytorch_optimizer/optimizer/adopt.py
@@ -1,3 +1,6 @@
import math
from typing import Callable, Optional

import torch

from pytorch_optimizer.base.exception import NoSparseGradientError
@@ -22,6 +25,7 @@ def __init__(
params: PARAMETERS,
lr: float = 1e-3,
betas: BETAS = (0.9, 0.9999),
clip_lambda: Optional[Callable[[float], float]] = lambda step: math.pow(step, 0.25),
weight_decay: float = 0.0,
weight_decouple: bool = False,
fixed_decay: bool = False,
@@ -33,6 +37,8 @@
self.validate_non_negative(weight_decay, 'weight_decay')
self.validate_non_negative(eps, 'eps')

self.clip_lambda = clip_lambda

defaults: DEFAULTS = {
'lr': lr,
'betas': betas,
@@ -104,10 +110,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1.0 - beta2)

de_nom = exp_avg_sq.sqrt().clamp_(min=group['eps'])
if group['step'] == 2:
exp_avg.addcdiv_(grad, de_nom)
else:
exp_avg.mul_(beta1).addcdiv_(grad, de_nom, value=1.0 - beta1)

normed_grad = grad.div(de_nom)
if self.clip_lambda is not None:
clip = self.clip_lambda(group['step'])
normed_grad.clamp_(-clip, clip)

exp_avg.lerp_(normed_grad, weight=1.0 - beta1)

p.add_(exp_avg, alpha=-group['lr'])
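
To illustrate the stability change: the step-2 special case is gone, and the normalized gradient is now clipped to `±clip_lambda(step)` before it updates the first moment, with the default bound growing as `step ** 0.25`. A standalone sketch of the new update path (the tensors and `eps` value here are placeholders, not the library's defaults):

```python
import math

import torch

clip_lambda = lambda step: math.pow(step, 0.25)  # default clip schedule from the diff above

def adopt_like_update(exp_avg, exp_avg_sq, grad, step, beta1=0.9, eps=1e-6):
    de_nom = exp_avg_sq.sqrt().clamp_(min=eps)
    normed_grad = grad.div(de_nom)
    clip = clip_lambda(step)
    normed_grad.clamp_(-clip, clip)  # early steps are clipped hard; the bound relaxes over time
    exp_avg.lerp_(normed_grad, weight=1.0 - beta1)
    return exp_avg

print([round(clip_lambda(s), 2) for s in (1, 10, 100, 10_000)])  # [1.0, 1.78, 3.16, 10.0]
```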

22 changes: 19 additions & 3 deletions pytorch_optimizer/optimizer/galore.py
@@ -7,7 +7,7 @@
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS

PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full']
PROJECTION_TYPE = Literal['std', 'reverse_std', 'right', 'left', 'full', 'random']


class GaLoreProjector:
@@ -16,8 +16,8 @@ class GaLoreProjector:
:param rank: int. low rank to project.
:param update_proj_gap: int. num steps to update the projection.
:param scale: float. scale factor.
:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' are
supported.
:param projection_type: PROJECTION_TYPE. type of projection. 'std', 'reverse_std', 'right', 'left', 'full' and
'random' are supported.
"""

def __init__(
@@ -101,6 +101,14 @@ def get_low_rank_grad_full(self, grad: torch.Tensor, steps: int) -> torch.Tensor
self.ortho_matrix = self.get_orthogonal_matrix(grad, self.rank, projection_type='full')
return torch.matmul(self.ortho_matrix[0].t(), grad) @ self.ortho_matrix[1].t()

def get_low_rank_grad_random(self, grad: torch.Tensor, steps: int) -> torch.Tensor:
is_right: bool = grad.size(0) >= grad.size(1)
if self.ortho_matrix is None or steps % self.update_proj_gap == 0:
self.ortho_matrix = self.get_orthogonal_matrix(
grad, self.rank, projection_type='right' if is_right else 'left'
)
return torch.matmul(grad, self.ortho_matrix.t()) if is_right else torch.matmul(self.ortho_matrix.t(), grad)

def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
if self.projection_type == 'std':
return self.get_low_rank_grad_std(full_rank_grad, steps)
@@ -112,6 +120,8 @@ def project(self, full_rank_grad: torch.Tensor, steps: int) -> torch.Tensor:
return self.get_low_rank_grad_left(full_rank_grad, steps)
if self.projection_type == 'full':
return self.get_low_rank_grad_full(full_rank_grad, steps)
if self.projection_type == 'random':
return self.get_low_rank_grad_random(full_rank_grad, steps)
raise NotImplementedError

def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
@@ -133,6 +143,12 @@ def project_back(self, low_rank_grad: torch.Tensor) -> torch.Tensor:
return torch.matmul(self.ortho_matrix, low_rank_grad) * self.scale
if self.projection_type == 'full':
return torch.matmul(self.ortho_matrix[0], low_rank_grad) @ self.ortho_matrix[1].t() * self.scale
if self.projection_type == 'random':
return (
torch.matmul(low_rank_grad, self.ortho_matrix.t())
if low_rank_grad.shape[0] >= low_rank_grad.shape[1]
else torch.matmul(self.ortho_matrix, low_rank_grad)
) * self.scale

raise NotImplementedError
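
A rough usage sketch of the new projection type, mirroring the configuration added to `tests/constants.py` below (the toy model is a placeholder, and `GaLore` is assumed to be exported at the package top level):

```python
import torch

from pytorch_optimizer import GaLore

model = torch.nn.Linear(16, 4)  # GaLore projects 2-D parameters such as this weight matrix
optimizer = GaLore(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-3,
    rank=2,
    scale=1.0,
    update_proj_gap=1,
    projection_type='random',
)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```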

6 changes: 6 additions & 0 deletions pytorch_optimizer/optimizer/lion.py
@@ -18,6 +18,7 @@ class Lion(BaseOptimizer):
:param use_gc: bool. use gradient centralization.
:param r: float. EMA factor. a value between 0.9 and 0.99 is preferred.
:param adanorm: bool. whether to use the AdaNorm variant.
:param cautious: bool. whether to use the Cautious variant.
"""

def __init__(
@@ -31,13 +32,15 @@
use_gc: bool = False,
r: float = 0.95,
adanorm: bool = False,
cautious: bool = False,
**kwargs,
):
self.validate_learning_rate(lr)
self.validate_betas(betas)
self.validate_non_negative(weight_decay, 'weight_decay')

self.use_gc = use_gc
self.cautious = cautious

defaults: DEFAULTS = {
'lr': lr,
@@ -114,6 +117,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
update.mul_(beta1).add_(grad, alpha=1.0 - beta1).sign_()
exp_avg.mul_(beta2).add_(s_grad, alpha=1.0 - beta2)

if self.cautious:
self.apply_cautious(update, grad)

p.add_(update, alpha=-group['lr'])

return loss
12 changes: 10 additions & 2 deletions tests/constants.py
@@ -375,6 +375,7 @@
(AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'ams_bound': True}, 120),
(AdaFactor, {'lr': 1e1, 'weight_decay': 1e-3, 'cautious': True}, 70),
(AdaFactor, {'lr': 1e1, 'betas': (None, 0.999), 'weight_decay': 1e-3}, 40),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
(Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
@@ -383,6 +384,7 @@
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3}, 5),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 5),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
(Lion, {'lr': 5e-1, 'weight_decay': 1e-3, 'cautious': True}, 5),
(AliG, {'max_lr': 5e-1, 'momentum': 0.9}, 5),
(AliG, {'max_lr': 5e-1, 'momentum': 0.9, 'adjusted_momentum': True}, 5),
(SM3, {'lr': 5e-1, 'momentum': 0.9, 'beta': 0.9}, 5),
@@ -469,6 +471,11 @@
{'lr': 5e-1, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 2, 'projection_type': 'full'},
5,
),
(
GaLore,
{'lr': 1e0, 'weight_decay': 1e-3, 'rank': 2, 'scale': 1.0, 'update_proj_gap': 1, 'projection_type': 'random'},
5,
),
(Adalite, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeSGD, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(ScheduleFreeAdamW, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
@@ -478,8 +485,9 @@
(Kate, {'lr': 5e-2}, 10),
(StableAdamW, {'lr': 1e0}, 5),
(AdamG, {'lr': 1e0}, 20),
(AdEMAMix, {'lr': 1e0}, 5),
(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
(AdEMAMix, {'lr': 1e0}, 3),
(AdEMAMix, {'lr': 1e0, 't_alpha_beta3': 5}, 3),
(AdEMAMix, {'lr': 1e0, 'cautious': True}, 2),
(
SOAP,
{'lr': 1e0, 'shampoo_beta': 0.95, 'precondition_frequency': 1, 'merge_dims': False, 'precondition_1d': True},
