Skip to content

Commit

Permalink
Added infinite lr schedules
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijkg committed Mar 25, 2024
1 parent 7267a74 commit 5f2a3f2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 1 deletion.
36 changes: 36 additions & 0 deletions megatron/learning_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __init__(
decay_style,
last_iter,
min_lr=0.0,
constant_lr=0.0,
constant_iters=None,
cooldown_iters=None,
timescale=None,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
use_mup=False,
Expand All @@ -43,9 +47,13 @@ def __init__(
self.optimizer = optimizer
self.start_lr = start_lr
self.min_lr = min_lr
self.constant_lr = constant_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter
self.end_iter = total_iters
self.constant_iters = constant_iters
self.cooldown_iters = cooldown_iters
self.timescale = timescale
assert self.end_iter > 0
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
Expand Down Expand Up @@ -84,6 +92,34 @@ def get_lr(self):
# exp(-0.693) = 1/2
end_iter = self.end_iter - self.warmup_iter
lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter)
elif self.decay_style == "infinite_cosine" or self.decay_style == "infinite_inv_sqrt":
if num_iters_ <= self.cooldown_iter:
if self.decay_style == "infinite_cosine":
lr = self.constant_lr + (
(self.start_lr-self.constant_lr)
/ 2.0
* (math.cos(math.pi * num_iters_ / self.cooldown_iter) + 1)
)
else:
def inv_f(t):
return (1/math.sqrt(1+(self.timescale*t))) - 1
lr = self.start_lr + (
(self.constant_lr - self.start_lr)
/ inv_f(1)
* (inv_f(num_iters_ / self.cooldown_iter))
)
return lr
else:
num_iters_ = num_iters_ - self.cooldown_iter
if num_iters_ <= self.constant_iter:
# Stay constant for constant_iters
lr = self.constant_lr
else:
# Go from constant iters to min LR using exponential decay in remaining iters
end_iter_ = self.end_iter - self.warmup_iter - self.constant_iter
num_iters_ = num_iters_ - self.constant_iter
exp_factor = -math.log(self.min_lr/self.constant_lr) / end_iter_
lr = self.constant_lr * math.exp(-1* exp_factor * num_iters_)
else:
lr = self.start_lr
return max(lr, self.min_lr)
Expand Down
22 changes: 21 additions & 1 deletion megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
LR Scheduler Arguments
"""

lr_decay_style: Literal["constant", "linear", "cosine", "exponential"] = "linear"
lr_decay_style: Literal["constant", "linear", "cosine", "exponential", "infinite_cosine", "infinite_inv_sqrt"] = "linear"
"""
Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'.
"""
Expand All @@ -546,11 +546,31 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
Minimum value for learning rate. The scheduler clips values below this threshold.
"""

constant_lr: float = 0.0
"""
Constant learning rate when using infinite cosine or infinite inv sqrt decay styles.
"""

warmup: float = 0.01
"""
Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
"""

cooldown_iters_perc: float = 0.0
"""
Percentage of total iterations to cooldown for.
"""

constant_iters_perc: float = 0.0
"""
Percentage of total iterations to keep the learning rate constant for.
"""

timescale: float = 1.0
"""
Timescale for the steepness of the inverse square root cooldown.
"""

override_lr_scheduler: bool = False
"""
Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
Expand Down
6 changes: 6 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,8 @@ def get_learning_rate_scheduler(optimizer, neox_args):
num_iters = max(1, num_iters)
init_step = 0
warmup_iter = neox_args.warmup * num_iters
constant_iters = neox_args.constant_iters_perc * num_iters
cooldown_iters = neox_args.cooldown_iters_perc * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=neox_args.lr,
Expand All @@ -664,6 +666,10 @@ def get_learning_rate_scheduler(optimizer, neox_args):
decay_style=neox_args.lr_decay_style,
last_iter=init_step,
min_lr=neox_args.min_lr,
constant_lr=neox_args.constant_lr,
constant_iters=constant_iters,
cooldown_iters=cooldown_iters,
timescale=neox_args.timescale,
use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
override_lr_scheduler=neox_args.override_lr_scheduler,
use_mup=neox_args.use_mup,
Expand Down

0 comments on commit 5f2a3f2

Please sign in to comment.