
Commit

fix ConsineAnnealing
poodarchu committed Apr 20, 2023
1 parent 7c673ba commit b1e4bc5
Showing 1 changed file with 97 additions and 81 deletions.
efg/solver/lr_schedulers.py — 178 changes: 97 additions & 81 deletions
@@ -83,124 +83,140 @@ def _compute_values(self) -> List[float]:
         return self.get_lr()


-class WarmupCosineAnnealingLR(_LRScheduler):
-    r"""Set the learning rate of each parameter group using a cosine annealing
-    schedule, where :math:`\eta_{max}` is set to the initial lr and
-    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
-    .. math::
-        \begin{aligned}
-            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
-            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
-            & T_{cur} \neq (2k+1)T_{max}; \\
-            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
-            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
-            & T_{cur} = (2k+1)T_{max}.
-        \end{aligned}
-    When last_epoch=-1, sets initial lr as lr. Notice that because the schedule
-    is defined recursively, the learning rate can be simultaneously modified
-    outside this scheduler by other operators. If the learning rate is set
-    solely by this scheduler, the learning rate at each step becomes:
-    .. math::
-        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
-        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)
-    It has been proposed in
-    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
-    implements the cosine annealing part of SGDR, and not the restarts.
-    Args:
-        optimizer (Optimizer): Wrapped optimizer.
-        T_max (int): Maximum number of iterations.
-        eta_min (float): Minimum learning rate. Default: 0.
-        last_epoch (int): The index of last epoch. Default: -1.
-        verbose (bool): If ``True``, prints a message to stdout for
-            each update. Default: ``False``.
-    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
-        https://arxiv.org/abs/1608.03983
+class LinearWarmupCosineAnnealingLR(_LRScheduler):
+    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
+    and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
+    .. warning::
+        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
+        after each iteration as calling it after each epoch will keep the starting lr at
+        warmup_start_lr for the first epoch which is 0 in most cases.
+    .. warning::
+        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
+        It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
+        :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
+        epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
+        train and validation methods.
+    Example:
+        >>> layer = nn.Linear(10, 1)
+        >>> optimizer = Adam(layer.parameters(), lr=0.02)
+        >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
+        >>> #
+        >>> # the default case
+        >>> for epoch in range(40):
+        ...     # train(...)
+        ...     # validate(...)
+        ...     scheduler.step()
+        >>> #
+        >>> # passing epoch param case
+        >>> for epoch in range(40):
+        ...     scheduler.step(epoch)
+        ...     # train(...)
+        ...     # validate(...)
     """

     def __init__(
         self,
-        optimizer,
-        max_iters,
-        warmup_iters=0,
-        warmup_method="linear",
-        warmup_factor=0.001,
-        eta_min=0,
-        last_epoch=-1,
-        verbose=False,
-    ):
-        self.T_max = max_iters
-        self.T_warmup = warmup_iters
-        self.warmup_method = warmup_method
-        self.warmup_factor = warmup_factor
+        optimizer: torch.optim.Optimizer,
+        warmup_epochs: int,
+        max_epochs: int,
+        warmup_start_lr: float = 0.0,
+        eta_min: float = 0.0,
+        last_epoch: int = -1,
+    ) -> None:
+        """
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            warmup_epochs (int): Maximum number of iterations for linear warmup
+            max_epochs (int): Maximum number of iterations
+            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
+            eta_min (float): Minimum learning rate. Default: 0.
+            last_epoch (int): The index of last epoch. Default: -1.
+        """
+        self.warmup_epochs = warmup_epochs
+        self.max_epochs = max_epochs
+        self.warmup_start_lr = warmup_start_lr
         self.eta_min = eta_min
-        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)
-
-    def get_lr(self):
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> List[float]:
+        """Compute learning rate using chainable form of the scheduler."""
         if not self._get_lr_called_within_step:
             warnings.warn(
-                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning
+                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
+                UserWarning,
             )

         if self.last_epoch == 0:
-            return [self.warmup_factor * base_lr for base_lr in self.base_lrs]
-        elif self.last_epoch < self.T_warmup:
-            warmup_factor = _get_warmup_factor_at_iter(
-                self.warmup_method, self.last_epoch, self.T_warmup, self.warmup_factor
-            )
+            return [self.warmup_start_lr] * len(self.base_lrs)
+        if self.last_epoch < self.warmup_epochs:
             return [
-                group["lr"] + (base_lr - base_lr * warmup_factor) / (self.T_warmup - 1)
+                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
             ]
-        elif (self.last_epoch - 1 - self.T_max) % (2 * (self.T_max - self.T_warmup)) == 0:
+        if self.last_epoch == self.warmup_epochs:
+            return self.base_lrs
+        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
             return [
-                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.T_max - self.T_warmup))) / 2
+                group["lr"]
+                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
             ]

         return [
-            (1 + math.cos(math.pi * (self.last_epoch - self.T_warmup) / (self.T_max - self.T_warmup)))
-            / (1 + math.cos(math.pi * (self.last_epoch - self.T_warmup - 1) / (self.T_max - self.T_warmup)))
+            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
+            / (
+                1
+                + math.cos(
+                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
+                )
+            )
             * (group["lr"] - self.eta_min)
             + self.eta_min
             for group in self.optimizer.param_groups
         ]

-    def _get_closed_form_lr(self):
+    def _get_closed_form_lr(self) -> List[float]:
+        """Called when epoch is passed as a param to the `step` function of the scheduler."""
-        if self.last_epoch < self.T_warmup:
-            warmup_factor = _get_warmup_factor_at_iter(
-                self.warmup_method, self.last_epoch, self.T_warmup, self.warmup_factor
-            )
+
+        if self.last_epoch < self.warmup_epochs:
             return [
-                base_lr * warmup_factor + self.last_epoch * (base_lr - base_lr * warmup_factor) / (self.T_warmup - 1)
+                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                 for base_lr in self.base_lrs
             ]

         return [
             self.eta_min
-            + (base_lr - self.eta_min)
-            * (1 + math.cos(math.pi * (self.last_epoch - self.T_warmup) / (self.T_max - self.T_warmup)))
-            / 2
+            + 0.5
+            * (base_lr - self.eta_min)
+            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
             for base_lr in self.base_lrs
         ]


 @LR_SCHEDULERS.register()
-class WarmupCosineAnnealing:
+class LinearWarmupCosineAnnealing:
     """
+    config.solver.lr_scheduer
+    Args:
+        max_iters / max_epochs: int
+        min_lr_ratio: float, default 0.001, multipier of config.solver.optimizer.lr
+        warmup_iters: int, default 1000
+        warmup_ratio: float, default 0.001
     """
     @staticmethod
     def build(config, optimizer):
         sconfig = config.solver.lr_scheduler
-        max_epochs = sconfig.pop("max_epochs")
-        epoch_iters = config.solver.lr_scheduler.pop("epoch_iters")
-        lr_scheduler = WarmupCosineAnnealingLR(optimizer, **sconfig)
-        sconfig.max_epochs = max_epochs
-        config.solver.lr_scheduler.epoch_iters = epoch_iters
-        return lr_scheduler
+        scheduler = LinearWarmupCosineAnnealingLR(
+            optimizer,
+            max_epochs=sconfig.max_iters,
+            eta_min=config.solver.optimizer.lr * sconfig.min_lr_ratio,
+            warmup_epochs=sconfig.warmup_iters,
+            warmup_start_lr=config.solver.optimizer.lr * sconfig.warmup_ratio,
+        )
+        return scheduler


 @LR_SCHEDULERS.register()
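Note: as a quick sanity check of the schedule this commit switches to, the snippet below evaluates the same closed form that the new `_get_closed_form_lr` implements: a linear ramp from `warmup_start_lr` to the base lr, followed by cosine decay to `eta_min`. This is a minimal standalone sketch; the helper name `linear_warmup_cosine_lr` and the sample numbers are illustrative and not part of the commit. Since `LinearWarmupCosineAnnealing.build` passes `sconfig.max_iters` as `max_epochs` and `sconfig.warmup_iters` as `warmup_epochs`, `step` here counts optimizer iterations rather than epochs.

import math

def linear_warmup_cosine_lr(step, base_lr, warmup_iters, max_iters, warmup_start_lr=0.0, eta_min=0.0):
    # Linear warmup: ramp from warmup_start_lr to base_lr over the first warmup_iters steps.
    if step < warmup_iters:
        return warmup_start_lr + step * (base_lr - warmup_start_lr) / (warmup_iters - 1)
    # Cosine annealing: decay from base_lr down to eta_min over the remaining steps.
    progress = (step - warmup_iters) / (max_iters - warmup_iters)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))

# lr reaches base_lr at the end of warmup, then decays to eta_min at max_iters.
for step in (0, 500, 1000, 30000, 60000):
    print(step, linear_warmup_cosine_lr(step, base_lr=1e-3, warmup_iters=1000, max_iters=60000))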

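For completeness, here is a hedged usage sketch of the argument mapping performed by the registered `LinearWarmupCosineAnnealing.build`: it reads `config.solver.lr_scheduler` and constructs the scheduler with `max_epochs=max_iters`, `warmup_epochs=warmup_iters`, `eta_min=lr * min_lr_ratio`, and `warmup_start_lr=lr * warmup_ratio`. The direct construction below assumes the class is importable from `efg.solver.lr_schedulers` and uses made-up numbers; it is not taken from the commit itself.

import torch
from torch import nn

from efg.solver.lr_schedulers import LinearWarmupCosineAnnealingLR  # assumed import path

base_lr = 0.01  # stands in for config.solver.optimizer.lr (illustrative)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)

# Mirrors what LinearWarmupCosineAnnealing.build would do with
# max_iters=100, warmup_iters=10, min_lr_ratio=0.001, warmup_ratio=0.001.
scheduler = LinearWarmupCosineAnnealingLR(
    optimizer,
    warmup_epochs=10,                  # sconfig.warmup_iters
    max_epochs=100,                    # sconfig.max_iters
    warmup_start_lr=base_lr * 0.001,   # lr * sconfig.warmup_ratio
    eta_min=base_lr * 0.001,           # lr * sconfig.min_lr_ratio
)

for _ in range(100):
    optimizer.step()   # one training iteration
    scheduler.step()   # step once per iteration, as the class docstring recommends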