[Fix] Wrapper optimizers #332

Merged 10 commits on Jan 25, 2025
11 changes: 11 additions & 0 deletions docs/changelogs/v3.3.5.md
@@ -1,8 +1,19 @@
### Change Log

### Feature

* Implement `FOCUS` optimizer. (#330, #331)
* [First Order Concentrated Updating Scheme](https://arxiv.org/abs/2501.12243)

### Update

* Support the `OrthoGrad` variant in `Ranger25`. (#332)

### Fix

* Add the missing `state` property in `OrthoGrad` optimizer. (#326, #327)
* Add the missing `state_dict` and `load_state_dict` methods to `TRAC` and `OrthoGrad` optimizers. (#332)
* Skip when the gradient is sparse in `OrthoGrad` optimizer. (#332)

### Contributions

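For the `Ranger25` entry above, a minimal usage sketch (the wrapper `state_dict`/`load_state_dict` fixes are sketched after the `orthograd.py` and `trac.py` diffs below). It assumes `Ranger25` is exported at the package top level:

```python
import torch

from pytorch_optimizer import Ranger25

model = torch.nn.Linear(4, 1)

# `orthograd=True` (the new default in this diff) orthogonalizes each gradient
# against its weight vector before the usual ADOPT/AdEMAMix/Cautious update.
optimizer = Ranger25(model.parameters(), lr=1e-3, orthograd=True)

model(torch.randn(2, 4)).sum().backward()
optimizer.step()
```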
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -97,8 +97,10 @@ select = [
"TID", "ARG", "ERA", "RUF", "YTT", "PL", "Q"
]
ignore = [
"A005", "B905", "D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413", "PIE790", "PLR0912", "PLR0913",
"PLR0915", "PLR2004", "RUF013", "Q003", "ARG002",
"A005", "B905",
"D100", "D102", "D104", "D105", "D107", "D203", "D213", "D413",
"PLR0912", "PLR0913", "PLR0915", "PLR2004",
"Q003", "ARG002",
]
fixable = ["ALL"]
unfixable = ["F401"]
23 changes: 22 additions & 1 deletion pytorch_optimizer/base/optimizer.py
@@ -6,7 +6,16 @@
from torch.optim import Optimizer

from pytorch_optimizer.base.exception import NegativeLRError, NegativeStepError
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, HUTCHINSON_G, LOSS, PARAMETERS, STATE
from pytorch_optimizer.base.types import (
BETAS,
CLOSURE,
DEFAULTS,
HUTCHINSON_G,
LOSS,
OPTIMIZER_INSTANCE_OR_CLASS,
PARAMETERS,
STATE,
)


class BaseOptimizer(ABC, Optimizer):
@@ -15,6 +24,18 @@ class BaseOptimizer(ABC, Optimizer):
def __init__(self, params: PARAMETERS, defaults: DEFAULTS) -> None:
super().__init__(params, defaults)

@staticmethod
def load_optimizer(optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> Optimizer:
r"""Build torch.optim.Optimizer class."""
if isinstance(optimizer, Optimizer):
return optimizer

if 'params' in kwargs:
params = kwargs.pop('params')
return optimizer(params, **kwargs)

raise ValueError('need to pass `params` when you pass a `torch.optim.Optimizer` class instead of an instance.')

@staticmethod
@torch.no_grad()
def set_hessian(param_groups: PARAMETERS, state: STATE, hessian: List[torch.Tensor]) -> None:
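A sketch of how the new `load_optimizer` helper behaves, called directly as a static method for illustration; the behavior follows the diff above:

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer

params = [torch.nn.Parameter(torch.zeros(2))]

# 1) A pre-built optimizer instance is returned unchanged.
opt_a = BaseOptimizer.load_optimizer(torch.optim.AdamW(params, lr=1e-3))

# 2) An optimizer class plus `params` (and constructor kwargs) is instantiated here.
opt_b = BaseOptimizer.load_optimizer(torch.optim.AdamW, params=params, lr=1e-3)

# 3) Passing a class without `params` raises ValueError.
```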
4 changes: 2 additions & 2 deletions pytorch_optimizer/loss/jaccard.py
@@ -49,7 +49,7 @@ class JaccardLoss(_Loss):
def __init__(
self,
mode: CLASS_MODE,
classes: List[int] = None,
classes: Optional[List[int]] = None,
log_loss: bool = False,
from_logits: bool = True,
label_smooth: float = 0.0,
@@ -59,7 +59,7 @@ def __init__(

if classes is not None:
if mode == 'binary':
raise ValueError('[-] Masking classes is not supported with mode=binary')
raise ValueError('masking classes is not supported with mode=binary')

classes = torch.LongTensor(classes)

24 changes: 23 additions & 1 deletion pytorch_optimizer/optimizer/experimental/ranger25.py
@@ -11,7 +11,7 @@
class Ranger25(BaseOptimizer):
r"""Mixin' every fancy optimizer hacks.

ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2
ADOPT + AdEMAMix + Cautious + StableAdamW + Adam-Atan2 + OrthoGrad

:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
:param lr: float. learning rate.
@@ -23,6 +23,7 @@ class Ranger25(BaseOptimizer):
:param t_alpha_beta3: Optional[float]. total number of iterations is preferred when needed.
:param cautious: bool. whether to use the Cautious variant.
:param stable_adamw: bool. whether to use stable AdamW variant.
:param orthograd: bool. whether to use the OrthoGrad variant.
:param eps: Optional[float]. term added to the denominator to improve numerical stability. when eps is None and
stable_adamw is False, adam-atan2 feature will be used.
"""
@@ -39,6 +40,7 @@ def __init__(
t_alpha_beta3: Optional[float] = None,
cautious: bool = True,
stable_adamw: bool = True,
orthograd: bool = True,
eps: Optional[float] = 1e-8,
**kwargs,
):
@@ -51,6 +53,7 @@

self.cautious = cautious
self.stable_adamw: bool = stable_adamw if isinstance(eps, float) else False
self.orthograd = orthograd

defaults: DEFAULTS = {
'lr': lr,
@@ -97,13 +100,32 @@ def schedule_beta3(t_alpha_beta3: Optional[float], step: int, beta1: float, beta
beta3,
)

@torch.no_grad()
def orthogonalize_gradients(self, params, eps: float = 1e-16) -> None:
for p in params:
if p.grad is None or p.grad.is_sparse:
continue

w = p.view(-1)
g = p.grad.view(-1)

proj = torch.dot(w, g).div_(torch.dot(w, w).add_(eps))
g_ortho = g.to(dtype=torch.float32, copy=True).sub_(w, alpha=proj)
g_ortho_scaled = g_ortho.mul_(g.norm(2).div_(g_ortho.norm(2).add_(eps)))

p.grad.copy_(g_ortho_scaled.view_as(p.grad))

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
with torch.enable_grad():
loss = closure()

if self.orthograd:
for group in self.param_groups:
self.orthogonalize_gradients(group['params'])

for group in self.param_groups:
if 'step' in group:
group['step'] += 1
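A standalone sketch of the projection `orthogonalize_gradients` applies per parameter: the gradient loses its component along the flattened weight vector and is then rescaled back to its original norm. This mirrors the diff above but is not the library code itself:

```python
import torch

eps = 1e-16
w = torch.randn(8)  # flattened parameter
g = torch.randn(8)  # its gradient

proj = torch.dot(w, g) / (torch.dot(w, w) + eps)
g_ortho = g - proj * w                                     # drop the component along w
g_ortho = g_ortho * (g.norm(2) / (g_ortho.norm(2) + eps))  # restore the original norm

print(torch.dot(w, g_ortho))  # ~0: the adjusted gradient is orthogonal to the weights
```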
8 changes: 1 addition & 7 deletions pytorch_optimizer/optimizer/lookahead.py
@@ -29,13 +29,7 @@
self.validate_range(alpha, 'alpha', 0.0, 1.0)
self.validate_options(pullback_momentum, 'pullback_momentum', ['none', 'reset', 'pullback'])

if isinstance(optimizer, Optimizer):
self.optimizer = optimizer
elif 'params' in kwargs:
params = kwargs.pop('params')
self.optimizer = optimizer(params, **kwargs)
else:
raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
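The refactor does not change how `Lookahead` is constructed; a brief usage sketch, where the `k` and `alpha` keywords are assumed from the existing `Lookahead` signature rather than shown in this diff:

```python
import torch

from pytorch_optimizer import Lookahead

params = [torch.nn.Parameter(torch.zeros(2))]
optimizer = Lookahead(torch.optim.AdamW(params, lr=1e-3), k=5, alpha=0.5)
```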
20 changes: 10 additions & 10 deletions pytorch_optimizer/optimizer/orthograd.py
@@ -4,7 +4,7 @@
from torch.optim import Optimizer

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE


class OrthoGrad(BaseOptimizer):
@@ -20,13 +20,7 @@ def __init__(self, optimizer: OPTIMIZER_INSTANCE_OR_CLASS, **kwargs) -> None:
self._optimizer_step_post_hooks: Dict[int, Callable] = {}
self.eps: float = 1e-30

if isinstance(optimizer, Optimizer):
self.optimizer = optimizer
elif 'params' in kwargs:
params = kwargs.pop('params')
self.optimizer = optimizer(params, **kwargs)
else:
raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')
self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

self.defaults: DEFAULTS = self.optimizer.defaults

@@ -38,9 +32,15 @@ def param_groups(self):
return self.optimizer.param_groups

@property
def state(self):
def state(self) -> STATE:
return self.optimizer.state

def state_dict(self) -> STATE:
return self.optimizer.state_dict()

def load_state_dict(self, state_dict: STATE) -> None:
self.optimizer.load_state_dict(state_dict)

@torch.no_grad()
def zero_grad(self) -> None:
self.optimizer.zero_grad(set_to_none=True)
@@ -52,7 +52,7 @@ def reset(self):
@torch.no_grad()
def orthogonalize_gradients(self, params) -> None:
for p in params:
if p.grad is None:
if p.grad is None or p.grad.is_sparse:
continue

w = p.view(-1)
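A checkpointing sketch based on the methods added above; the sparse-gradient change simply means parameters whose `.grad.is_sparse` is true are skipped by the orthogonalization:

```python
import torch

from pytorch_optimizer import OrthoGrad

model = torch.nn.Linear(4, 1)
optimizer = OrthoGrad(torch.optim.AdamW(model.parameters(), lr=1e-3))

model(torch.randn(2, 4)).sum().backward()
optimizer.step()

checkpoint = optimizer.state_dict()    # now delegates to the wrapped AdamW
optimizer.load_state_dict(checkpoint)  # restores the wrapped optimizer's state
```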
1 change: 0 additions & 1 deletion pytorch_optimizer/optimizer/shampoo_utils.py
@@ -31,7 +31,6 @@ def __init__(self, *args):

def add_statistics(self, grad: torch.Tensor, unused_beta2: float) -> None:
r"""Add the statistics."""
pass

def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
r"""Get preconditioned gradient."""
24 changes: 12 additions & 12 deletions pytorch_optimizer/optimizer/trac.py
@@ -5,7 +5,7 @@
from torch.optim import Optimizer

from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, OPTIMIZER_INSTANCE_OR_CLASS, STATE


def polyval(x: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
@@ -112,23 +112,17 @@ def __init__(
self.validate_non_negative(s_prev, 's_prev')
self.validate_non_negative(eps, 'eps')

if isinstance(optimizer, Optimizer):
self.optimizer = optimizer
elif 'params' in kwargs:
params = kwargs.pop('params')
self.optimizer = optimizer(params, **kwargs)
else:
raise ValueError('Need to pass `params` when you pass the torch.optim.Optimizer instance.')

self._optimizer_step_pre_hooks: Dict[int, Callable] = {}
self._optimizer_step_post_hooks: Dict[int, Callable] = {}

self.erf = ERF1994(num_coefs=num_coefs)
self.optimizer: Optimizer = self.load_optimizer(optimizer, **kwargs)

self.betas = betas
self.s_prev = s_prev
self.eps = eps

self.f_term = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))
self.erf: nn.Module = ERF1994(num_coefs=num_coefs)
self.f_term: torch.Tensor = self.s_prev / self.erf_imag(1.0 / torch.sqrt(torch.tensor(2.0)))

self.defaults: DEFAULTS = self.optimizer.defaults

@@ -140,9 +134,15 @@ def param_groups(self):
return self.optimizer.param_groups

@property
def state(self):
def state(self) -> STATE:
return self.optimizer.state

def state_dict(self) -> STATE:
return self.optimizer.state_dict()

def load_state_dict(self, state_dict: STATE) -> None:
self.optimizer.load_state_dict(state_dict)

@torch.no_grad()
def reset(self):
device = self.param_groups[0]['params'][0].device
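`TRAC` now shares the same construction paths and checkpoint interface; a sketch using the class-plus-`params` form, with `lr` forwarded to the wrapped `AdamW`:

```python
import torch

from pytorch_optimizer import TRAC

model = torch.nn.Linear(4, 1)
optimizer = TRAC(torch.optim.AdamW, params=model.parameters(), lr=1e-3)

model(torch.randn(2, 4)).sum().backward()
optimizer.step()
optimizer.load_state_dict(optimizer.state_dict())  # round-trips the wrapped state
```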
3 changes: 2 additions & 1 deletion pytorch_optimizer/optimizer/utils.py
@@ -48,7 +48,8 @@ def is_deepspeed_zero3_enabled() -> bool:
return is_deepspeed_zero3_enabled() # pragma: no cover

warnings.warn(
'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. it\'ll return False.',
'you need to install `transformers` to use `is_deepspeed_zero3_enabled` function. '
'it will return False.',
category=ImportWarning,
stacklevel=2,
)
6 changes: 3 additions & 3 deletions tests/constants.py
@@ -557,9 +557,9 @@
(TAM, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
(AdaTAM, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
(FOCUS, {'lr': 1e-1, 'weight_decay': 1e-3}, 5),
(Ranger25, {'lr': 1e0}, 5),
(Ranger25, {'lr': 1e0, 't_alpha_beta3': 5}, 5),
(Ranger25, {'lr': 1e-1, 'stable_adamw': False, 'eps': None}, 5),
(Ranger25, {'lr': 5e0}, 2),
(Ranger25, {'lr': 5e0, 't_alpha_beta3': 5}, 2),
(Ranger25, {'lr': 2e-1, 'stable_adamw': False, 'orthograd': False, 'eps': None}, 3),
]
ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
(AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),
11 changes: 9 additions & 2 deletions tests/test_optimizer_parameters.py
@@ -281,11 +281,18 @@ def test_galore_projection_type():


@pytest.mark.parametrize('optimizer_instance', [Lookahead, OrthoGrad, TRAC])
def test_load_optimizer(optimizer_instance):
def test_load_wrapper_optimizer(optimizer_instance):
params = [simple_parameter()]

_ = optimizer_instance(torch.optim.AdamW(params))
_ = optimizer_instance(torch.optim.AdamW, params=params)
optimizer = optimizer_instance(torch.optim.AdamW, params=params)
optimizer.zero_grad()

with pytest.raises(ValueError):
optimizer_instance(torch.optim.AdamW)

_ = optimizer.param_groups
_ = optimizer.state

state = optimizer.state_dict()
optimizer.load_state_dict(state)