Merge pull request BlinkDL#7 from PicoCreator/dev-infctx-lr-final
Added support for lr_final
PicoCreator authored Jun 29, 2023
2 parents 6792641 + 32ae35d commit 84e159e
Showing 1 changed file with 99 additions and 8 deletions.
107 changes: 99 additions & 8 deletions RWKV-v4neo/src/model.py
@@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import gc, math
import os, math, sys
from random import randint
from typing import List, Optional

@@ -254,6 +254,9 @@ def __init__(self,
vocab_size: int,
grad_cp: bool,
lr_init: float,
lr_final: float = -1.0,
lr_period: int = -1,
lr_period_type: str = 'epoch',
warmup_steps: int = -1,
beta1: float = 0.9,
beta2: float = 0.99,
@@ -274,6 +277,9 @@ def __init__(self,
self.layerwise_lr = layerwise_lr
self.grad_cp = grad_cp
self.lr_init = lr_init
self.lr_final = lr_final
self.lr_period = lr_period
self.lr_period_type = lr_period_type
self.warmup_steps = warmup_steps
self.beta1 = beta1
self.beta2 = beta2
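The new hyperparameters use -1 as a sentinel: a negative lr_final keeps the learning rate constant, and a negative lr_period falls back to the trainer's max_steps / max_epochs. A minimal sketch of the lr_final rule, using a hypothetical helper and illustrative values (not part of this diff):

```python
# Hypothetical helper mirroring the rule used later in configure_optimizers:
# a negative lr_final means "keep the LR constant at lr_init".
def resolve_lr_range(lr_init: float, lr_final: float = -1.0):
    ending_lr = lr_final if lr_final >= 0 else lr_init
    return lr_init, ending_lr

assert resolve_lr_range(6e-4) == (6e-4, 6e-4)        # default: no decay
assert resolve_lr_range(6e-4, 1e-5) == (6e-4, 1e-5)  # decay from lr_init to lr_final
```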
@@ -354,9 +360,15 @@ def configure_optimizers(self):
},
]

# Default ending_lr to starting_lr when lr_final is left unset (negative)
starting_lr = self.lr_init
ending_lr = self.lr_final
if ending_lr < 0:
ending_lr = self.lr_init

if self.deepspeed_offload:
optimizer = DeepSpeedCPUAdam(optim_groups,
lr=self.lr_init,
lr=starting_lr,
betas=(self.beta1, self.beta2),
eps=self.adam_eps,
bias_correction=True,
@@ -365,13 +377,18 @@ def configure_optimizers(self):
amsgrad=False)
else:
optimizer = FusedAdam(optim_groups,
lr=self.lr_init,
lr=starting_lr,
betas=(self.beta1, self.beta2),
eps=self.adam_eps,
bias_correction=True,
adam_w_mode=False,
weight_decay=self.weight_decay,
amsgrad=False)

# Raise if warmup_steps and lr_period are both set (not supported)
if self.warmup_steps > 0 and self.lr_period > 0:
raise ValueError(
"Use either warmup_steps or lr_period, not both.")

if self.warmup_steps > 0:
lr_scheduler = deepspeed.runtime.lr_schedules.WarmupLR(
@@ -381,10 +398,83 @@ def configure_optimizers(self):
warmup_num_steps=self.warmup_steps,
warmup_type='linear')

return optimizer, lr_scheduler
return {
'optimizer': optimizer,
'lr_scheduler': lr_scheduler,
}

else:
return optimizer
# Skip the lr_scheduler process if lr_init and lr_final are the same
if starting_lr == ending_lr:
return optimizer

# The total number of steps over which to decay the learning rate
lr_total_step = 0

# Handle the lr_period == -1 default behaviour of using the trainer's max_steps / max_epochs
if self.lr_period == -1:
# Get the trainer's max_steps / max_epochs
trainer_max_step = self.trainer.max_steps
trainer_max_epoch = self.trainer.max_epochs
if trainer_max_step > 0:
lr_total_step = trainer_max_step
elif trainer_max_epoch > 0:
lr_total_step = trainer_max_epoch * self.num_step_per_epoch()
else:
print("Warning: max_steps/max_epochs not set, performing the lr_init to lr_final decay assuming 10 epochs")
lr_total_step = 10 * self.num_step_per_epoch()
else:
# Calculate lr_total_step based on lr_period
if self.lr_period_type == "step":
lr_total_step = self.lr_period
elif self.lr_period_type == "epoch":
lr_total_step = self.lr_period * self.num_step_per_epoch()
else:
raise ValueError(f"lr_period_type {self.lr_period_type} not supported.")

# Initialize the lr_scheduler
lr_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1.0,
end_factor=ending_lr / starting_lr,
total_iters=lr_total_step
)

return {
'optimizer': optimizer,
'lr_scheduler': {
"scheduler": lr_scheduler,
"interval": "step",
"frequency": 1,
},
}
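A self-contained sketch of the decay behaviour configured above: LinearLR scales the optimizer's base LR from start_factor to end_factor over total_iters steps, so with end_factor = lr_final / lr_init the learning rate moves linearly from lr_init down to lr_final. The dummy parameter and the concrete numbers below are illustrative assumptions, not values from this diff.

```python
import torch

# Illustrative values (assumptions, not from this diff)
lr_init, lr_final, total_steps = 6e-4, 1e-5, 1000

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=lr_init)

# Same construction as in the diff: scale the LR linearly from
# 1.0 * lr_init down to (lr_final / lr_init) * lr_init == lr_final.
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1.0,
    end_factor=lr_final / lr_init,
    total_iters=total_steps,
)

for step in range(total_steps):
    optimizer.step()
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # ~1e-5 after total_steps updates
```

The returned dict with interval "step" and frequency 1 tells PyTorch Lightning to step this scheduler after every optimizer step; when lr_period_type is 'epoch', the period is first converted to steps via num_step_per_epoch() before the scheduler is built.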


# We have to compute the number of steps per epoch ourselves
# as this value is not provided directly by pytorch lightning
# https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
def num_step_per_epoch(self) -> int:
# Estimated number of steps in total, as added in
# https://github.com/Lightning-AI/lightning/pull/11599
#
# This MUST be accessed before len(self.trainer.train_dataloader),
# otherwise the train_dataloader is not fully initialized,
# which seems to be resolved by computing
# self.trainer.estimated_stepping_batches first
estimated_stepping_batches = self.trainer.estimated_stepping_batches

# Get the number of epochs;
# if max_epochs is set, derive steps per epoch from estimated_stepping_batches
max_epochs = self.trainer.max_epochs
if max_epochs > 0:
return estimated_stepping_batches // max_epochs

# max_epochs is not set, fall back to the train_dataloader length
dataset_size = len(self.trainer.train_dataloader)
num_devices = max(1, self.trainer.num_devices)
num_steps = dataset_size // (self.trainer.accumulate_grad_batches * num_devices)
return num_steps
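When max_epochs is not set, steps per epoch fall back to the dataloader length divided by the gradient-accumulation factor times the device count. A small worked example with assumed numbers:

```python
# Illustrative assumptions: 10_000 batches in the dataloader,
# gradients accumulated over 4 batches, training on 8 devices.
dataset_size = 10_000            # len(trainer.train_dataloader)
accumulate_grad_batches = 4
num_devices = 8

steps_per_epoch = dataset_size // (accumulate_grad_batches * num_devices)
print(steps_per_epoch)  # 312 optimizer steps per epoch
```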

@property
def deepspeed_offload(self) -> bool:
strategy = self.trainer.strategy
@@ -513,10 +603,11 @@ def checkpointed_step(idx, targets, mask, prev_loss, last_shift_states,
# Wandb logging only, if an active run exists
if wandb.run is not None:
wandb.log({
'substep': batch_idx,
'real_ctx_len': T,
'train/loss': total_loss,
'trainer/global_step': self.global_step,
'trainer/learning_rate': self.trainer.optimizers[0].param_groups[0]['lr']
})

return total_loss
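The added 'trainer/learning_rate' metric is read straight off the first optimizer's first param group, which reflects the scheduler's current value. A standalone sketch of that lookup, with a dummy optimizer and an assumed starting LR:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=6e-4)  # illustrative lr_init

# Same lookup as the wandb.log call above; under a scheduler this value
# changes every step, so logging it traces the decay curve in wandb.
current_lr = optimizer.param_groups[0]['lr']
print(current_lr)  # 0.0006
```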
