Add inverse sqrt learning rate scheduler #21495

Merged: 10 commits, Feb 7, 2023
2 changes: 2 additions & 0 deletions docs/source/en/main_classes/optimizer_schedules.mdx
@@ -60,6 +60,8 @@ The `.optimization` module provides:

[[autodoc]] get_polynomial_decay_schedule_with_warmup

[[autodoc]] get_inverse_sqrt_schedule

### Warmup (TensorFlow)

[[autodoc]] WarmUp
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -2588,6 +2588,7 @@
"get_constant_schedule_with_warmup",
"get_cosine_schedule_with_warmup",
"get_cosine_with_hard_restarts_schedule_with_warmup",
"get_inverse_sqrt_schedule",
"get_linear_schedule_with_warmup",
"get_polynomial_decay_schedule_with_warmup",
"get_scheduler",
@@ -5659,6 +5660,7 @@
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_inverse_sqrt_schedule,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
40 changes: 40 additions & 0 deletions src/transformers/optimization.py
@@ -220,13 +220,50 @@ def lr_lambda(current_step: int):
    return LambdaLR(optimizer, lr_lambda, last_epoch)


def get_inverse_sqrt_schedule(
    optimizer: Optimizer, num_warmup_steps: int, timescale: int = None, last_epoch: int = -1
):
"""
Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.

Args:
optimizer ([`~torch.optim.Optimizer`]):
The optimizer for which to schedule the learning rate.
num_warmup_steps (`int`):
The number of steps for the warmup phase.
timescale (`int`, *optional*, defaults to `num_warmup_steps`):
Time scale.
last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training.

Note: this implementation is adapted from
https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
[Collaborator review comment on the note above: "This should be a comment in the doc, the user reading the documentation won't really care about this."]

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    if timescale is None:
        timescale = num_warmup_steps

    def lr_lambda(current_step: int):
        # Linear warmup from 0 up to the initial lr.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Inverse square-root decay, shifted so the factor is exactly 1.0 at the end of warmup.
        shift = timescale - num_warmup_steps
        decay = 1.0 / math.sqrt((current_step + shift) / timescale)
        return decay

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
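A minimal usage sketch (not part of the diff; the toy model, lr, and step counts are illustrative assumptions). After warmup, the lr decays as `initial_lr / sqrt((step + timescale - num_warmup_steps) / timescale)`, which is exactly the initial lr at `step == num_warmup_steps`:

```python
import torch

from transformers import get_inverse_sqrt_schedule

# Toy model and optimizer, purely for illustration.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# timescale defaults to num_warmup_steps, so the decay factor is 1.0 at step 1000.
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=1000)

for step in range(10000):
    # ... forward pass and loss.backward() would go here ...
    optimizer.step()
    scheduler.step()  # advance the schedule once per optimizer step
    optimizer.zero_grad()
```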


TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
}


@@ -263,6 +300,9 @@ def get_scheduler(
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    if name == SchedulerType.INVERSE_SQRT:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
1 change: 1 addition & 0 deletions src/transformers/trainer_utils.py
@@ -363,6 +363,7 @@ class SchedulerType(ExplicitEnum):
POLYNOMIAL = "polynomial"
CONSTANT = "constant"
CONSTANT_WITH_WARMUP = "constant_with_warmup"
INVERSE_SQRT = "inverse_sqrt"


class TrainerMemoryTracker:
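Because `TrainingArguments.lr_scheduler_type` is parsed into this enum, the new entry also exposes the schedule by name in the `Trainer` API. A hedged sketch, assuming an otherwise standard training setup (`output_dir` and the warmup value are placeholders):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # placeholder
    lr_scheduler_type="inverse_sqrt",  # resolves to SchedulerType.INVERSE_SQRT
    warmup_steps=1000,
)
```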
4 changes: 4 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -7019,6 +7019,10 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(*args, **kwargs):
    requires_backends(get_cosine_with_hard_restarts_schedule_with_warmup, ["torch"])


def get_inverse_sqrt_schedule(*args, **kwargs):
    requires_backends(get_inverse_sqrt_schedule, ["torch"])


def get_linear_schedule_with_warmup(*args, **kwargs):
    requires_backends(get_linear_schedule_with_warmup, ["torch"])

5 changes: 5 additions & 0 deletions tests/optimization/test_optimization.py
@@ -33,6 +33,7 @@
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    get_cosine_with_hard_restarts_schedule_with_warmup,
    get_inverse_sqrt_schedule,
    get_linear_schedule_with_warmup,
    get_polynomial_decay_schedule_with_warmup,
)
@@ -145,6 +146,10 @@ def test_schedulers(self):
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
        }

        for scheduler_func, data in scheds.items():
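The expected learning rates in the new entry can be reproduced by hand. A sanity-check sketch, assuming an initial lr of 10.0 (consistent with the other rows of this table):

```python
import torch

from transformers import get_inverse_sqrt_schedule

optimizer = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=10.0)
scheduler = get_inverse_sqrt_schedule(optimizer, num_warmup_steps=2)

lrs = []
for _ in range(10):
    # Record the lr used at this step, then advance the schedule.
    lrs.append(round(optimizer.param_groups[0]["lr"], 3))
    optimizer.step()
    scheduler.step()

print(lrs)  # [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714]
```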