Merge pull request #857 from bghira/feature/optim-schedulefree
add schedulefree optim w/ kahan summation
bghira authored Aug 23, 2024
2 parents 177190d + 9c4736b commit fd78777
Showing 7 changed files with 254 additions and 31 deletions.
35 changes: 31 additions & 4 deletions helpers/training/optimizer_param.py
@@ -9,7 +9,8 @@
logger.setLevel(target_level)

is_optimi_available = False
- from helpers.training.adam_bfloat16 import AdamWBF16
+ from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
+ from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan

try:
    from optimum.quanto import QTensor
@@ -35,6 +36,17 @@
        },
        "class": AdamWBF16,
    },
    "adamw_schedulefree": {
        "precision": "any",
        "override_lr_scheduler": True,
        "can_warmup": True,
        "default_settings": {
            "betas": (0.9, 0.999),
            "weight_decay": 1e-2,
            "eps": 1e-8,
        },
        "class": AdamWScheduleFreeKahan,
    },
    "optimi-stableadamw": {
        "precision": "any",
        "default_settings": {
@@ -154,8 +166,8 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw or optimi-lion instead - for decoupled learning rate, use --optimizer_config=decoupled_lr=True.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw instead.",
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
}
@@ -206,6 +218,16 @@ def optimizer_parameters(optimizer, args):
    raise ValueError(f"Optimizer {optimizer} not found.")


def is_lr_scheduler_disabled(optimizer: str):
    """Check if the optimizer has a built-in LR scheduler"""
    is_disabled = False
    if optimizer in optimizer_choices:
        is_disabled = optimizer_choices.get(optimizer).get(
            "override_lr_scheduler", False
        )
    return is_disabled


def show_optimizer_defaults(optimizer: str = None):
    """we'll print the defaults on a single line, eg. foo=bar, buz=baz"""
    if optimizer is None:
@@ -260,7 +282,12 @@ def determine_optimizer_class_with_config(
  else:
      optimizer_class, optimizer_details = optimizer_parameters(args.optimizer, args)
      default_settings = optimizer_details.get("default_settings")
-     logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
+     if optimizer_details.get("can_warmup", False):
+         logger.info(
+             f"Optimizer contains LR scheduler, warmup steps will be set to {args.lr_warmup_steps}."
+         )
+         default_settings["warmup_steps"] = args.lr_warmup_steps
+     logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
  return default_settings, optimizer_class


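For context, the two registry flags added above are what the rest of this diff keys off: override_lr_scheduler makes is_lr_scheduler_disabled() return True (so train.py skips building an external LR scheduler), and can_warmup routes --lr_warmup_steps into the optimizer's own constructor. A minimal sketch of that flow, with a hypothetical Args stand-in for the parsed CLI arguments (illustrative only, not code from this diff):

from helpers.training.optimizer_param import (
    optimizer_choices,
    is_lr_scheduler_disabled,
)

class Args:  # hypothetical stand-in for the parsed CLI arguments
    optimizer = "adamw_schedulefree"
    lr_warmup_steps = 100

args = Args()
settings = dict(optimizer_choices[args.optimizer]["default_settings"])
if optimizer_choices[args.optimizer].get("can_warmup", False):
    # warmup moves into the optimizer instead of an LR scheduler
    settings["warmup_steps"] = args.lr_warmup_steps

if is_lr_scheduler_disabled(args.optimizer):
    lr_scheduler = None  # the optimizer schedules its own LR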
File renamed without changes (helpers.training.adam_bfloat16 → helpers.training.optimizers.adamw_bfloat16, per the updated import above).
149 changes: 149 additions & 0 deletions helpers/training/optimizers/adamw_schedulefree/__init__.py
@@ -0,0 +1,149 @@
import torch
from torch.optim.optimizer import Optimizer
import math
from typing import Iterable
from helpers.training.state_tracker import StateTracker


class AdamWScheduleFreeKahan(Optimizer):
    """AdamW optimizer with schedule-free adjustments and Kahan summation.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        betas: Coefficients for gradient and squared gradient moving averages (default: (0.9, 0.999)).
        eps: Added to denominator to improve numerical stability (default: 1e-8).
        weight_decay: Weight decay coefficient (default: 1e-2).
        warmup_steps: Number of steps to warm up the learning rate (default: 0).
        kahan_sum: Enables Kahan summation for more accurate parameter updates when training in low precision.
    """

    def __init__(
        self,
        params: Iterable,
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        warmup_steps: int = 0,
        kahan_sum: bool = True,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            warmup_steps=warmup_steps,
            kahan_sum=kahan_sum,
        )
        super(AdamWScheduleFreeKahan, self).__init__(params, defaults)
        self.k = 0
        self.lr_max = -1.0
        self.last_lr = -1.0
        self.weight_sum = 0.0

    def _initialize_state(self, state, p):
        if "step" not in state:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
            state["exp_avg_sq"] = torch.zeros_like(
                p, memory_format=torch.preserve_format
            )
            if self.defaults["kahan_sum"]:
                state["kahan_comp"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

    def eval(self):
        # Interpolate toward the averaged weights for evaluation. Note the
        # interpolation only runs when a "z" buffer exists in the state, and
        # nothing in this version creates one, so in practice only the
        # train_mode flag is toggled here.
        for group in self.param_groups:
            train_mode = group.get("train_mode", True)
            beta1, _ = group["betas"]
            if train_mode:
                for p in group["params"]:
                    state = self.state[p]
                    if "z" in state:
                        # Set p.data to x
                        p.data.lerp_(
                            end=state["z"].to(p.data.device), weight=1 - 1 / beta1
                        )
                group["train_mode"] = False

    def train(self):
        for group in self.param_groups:
            train_mode = group.get("train_mode", False)
            beta1, _ = group["betas"]
            if not train_mode:
                for p in group["params"]:
                    state = self.state[p]
                    if "z" in state:
                        # Set p.data to y
                        p.data.lerp_(end=state["z"].to(p.data.device), weight=1 - beta1)
                group["train_mode"] = True

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            warmup_steps = group["warmup_steps"]
            kahan_sum = group["kahan_sum"]

            k = self.k

            # Adjust learning rate with warmup
            if k < warmup_steps:
                sched = (k + 1) / warmup_steps
            else:
                sched = 1.0

            bias_correction2 = 1 - beta2 ** (k + 1)
            adjusted_lr = lr * sched * (bias_correction2**0.5)
            self.lr_max = max(adjusted_lr, self.lr_max)

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data

                state = self.state[p]
                self._initialize_state(state, p)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                if kahan_sum:
                    # fold the rounding error captured on the previous step
                    # back into this step's gradient
                    kahan_comp = state["kahan_comp"]
                    grad.add_(kahan_comp)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                step_size = adjusted_lr / (bias_correction2**0.5)

                if weight_decay != 0:
                    # decoupled weight decay (applied at full strength here,
                    # not scaled by the learning rate)
                    p.data.add_(p.data, alpha=-weight_decay)

                # Kahan summation to improve precision
                step = exp_avg / denom
                p.data.add_(-step_size * step)

                if kahan_sum:
                    # recompute the update out of place and difference the two
                    # results; in exact arithmetic this is zero, in low
                    # precision it estimates the rounding error of the update
                    buffer = p.data.add(-step_size * step)
                    kahan_comp.copy_(p.data.sub(buffer).add(buffer.sub_(p.data)))

        self.k += 1
        self.last_lr = adjusted_lr
        StateTracker.set_last_lr(adjusted_lr)

        return loss

    def get_last_lr(self):
        return self.last_lr
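For context, a minimal usage sketch of the class above (the import path matches this diff; the toy model, data, and step counts are illustrative assumptions):

import torch
from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan

model = torch.nn.Linear(4, 1)
optimizer = AdamWScheduleFreeKahan(model.parameters(), lr=1e-3, warmup_steps=10)

optimizer.train()  # schedule-free optimizers track a train/eval mode
for _ in range(100):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

optimizer.eval()   # switch modes before validation or checkpointing,
                   # mirroring what train.py does below
optimizer.train()  # and back again before resuming training

And for background, the classic scalar form of Kahan (compensated) summation that the kahan_comp buffer approximates — a textbook sketch, not code from this repository:

def kahan_sum(values):
    total, comp = 0.0, 0.0
    for v in values:
        y = v - comp            # apply the stored correction
        t = total + y           # low-order bits of y may be lost here
        comp = (t - total) - y  # recover the lost bits for the next step
        total = t
    return total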
11 changes: 11 additions & 0 deletions helpers/training/state_tracker.py
@@ -49,6 +49,9 @@ class StateTracker:
    # Aspect to resolution map, we'll store once generated for consistency.
    aspect_resolution_map = {}

    # for schedulefree
    last_lr = 0.0

    # hugging face hub user details
    hf_user = None

@@ -530,3 +533,11 @@ def load_aspect_resolution_map(cls, dataloader_resolution: float):
        logger.debug(
            f"Aspect resolution map: {cls.aspect_resolution_map[dataloader_resolution]}"
        )

    @classmethod
    def get_last_lr(cls):
        return cls.last_lr

    @classmethod
    def set_last_lr(cls, last_lr: float):
        cls.last_lr = float(last_lr)
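As a quick sanity check on the value set_last_lr will receive, here is the warmup and bias-correction arithmetic from the optimizer's step() worked for a few step counts (lr=1e-3, beta2=0.999, warmup_steps=10 are illustrative choices, not defaults from this diff):

import math

lr, beta2, warmup_steps = 1e-3, 0.999, 10
for k in (0, 4, 9, 10_000):
    sched = (k + 1) / warmup_steps if k < warmup_steps else 1.0
    adjusted_lr = lr * sched * math.sqrt(1 - beta2 ** (k + 1))
    print(k, f"{adjusted_lr:.2e}")
# 0      3.16e-06  (deep in warmup, strong bias correction)
# 4      3.53e-05
# 9      9.98e-05
# 10000  1.00e-03  (warmup done, bias correction ~1: full lr)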
72 changes: 54 additions & 18 deletions train.py
@@ -33,6 +33,8 @@
from helpers.training.validation import Validation, prepare_validation_prompt_list
from helpers.training.state_tracker import StateTracker
from helpers.training.schedulers import load_scheduler_from_args
from helpers.training.custom_schedule import get_lr_scheduler
from helpers.training.optimizer_param import is_lr_scheduler_disabled
from helpers.training.adapter import determine_adapter_target_modules, load_lora_weights
from helpers.training.diffusion_model import load_diffusion_model
from helpers.training.text_encoding import (
@@ -839,9 +841,15 @@ def main():
f" {args.num_train_epochs} epochs and {num_update_steps_per_epoch} steps per epoch."
)
overrode_max_train_steps = True
- logger.info(
-     f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
- )
+ is_schedulefree = is_lr_scheduler_disabled(args.optimizer)
+ if is_schedulefree:
+     logger.info(
+         "Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation."
+     )
+ else:
+     logger.info(
+         f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
+     )
if args.layer_freeze_strategy == "bitfit":
from helpers.training.model_freeze import apply_bitfit_freezing

@@ -978,26 +986,31 @@ def main():
optimizer,
)

- from helpers.training.custom_schedule import get_lr_scheduler
-
- if not use_deepspeed_scheduler:
+ if is_lr_scheduler_disabled(args.optimizer):
+     # we don't use LR schedulers with schedule-free optimisers (lol)
+     lr_scheduler = None
+ if not use_deepspeed_scheduler and not is_schedulefree:
      logger.info(
          f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
      )
      lr_scheduler = get_lr_scheduler(
          args, optimizer, accelerator, logger, use_deepspeed_scheduler=False
      )
  else:
-     logger.info(f"Using DeepSpeed learning rate scheduler")
-     lr_scheduler = accelerate.utils.DummyScheduler(
-         optimizer,
-         total_num_steps=args.max_train_steps,
-         warmup_num_steps=args.lr_warmup_steps,
-     )
- if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
-     lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch
- if hasattr(lr_scheduler, "last_step"):
-     lr_scheduler.last_step = global_resume_step
+     logger.info(f"Using dummy learning rate scheduler")
+     if torch.backends.mps.is_available():
+         lr_scheduler = None
+     else:
+         lr_scheduler = accelerate.utils.DummyScheduler(
+             optimizer,
+             total_num_steps=args.max_train_steps,
+             warmup_num_steps=args.lr_warmup_steps,
+         )
+ if lr_scheduler is not None:
+     if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
+         lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch
+     if hasattr(lr_scheduler, "last_step"):
+         lr_scheduler.last_step = global_resume_step

accelerator.wait_for_everyone()

@@ -1285,6 +1298,10 @@ def main():
if "sampler" in backend:
backend["sampler"].log_state()

if is_lr_scheduler_disabled(args.optimizer) and hasattr(optimizer, "train"):
    # we typically have to call train() on the optim for schedulefree.
    optimizer.train()

total_steps_remaining_at_start = args.max_train_steps
# We store the number of dataset resets that have occurred inside the checkpoint.
first_epoch = StateTracker.get_epoch()
@@ -1399,7 +1416,11 @@ def main():
if webhook_handler is not None:
webhook_handler.send(message=initial_msg)
if args.validation_on_startup and global_step <= 1:
    if is_lr_scheduler_disabled(args.optimizer):
        optimizer.eval()
    validation.run_validations(validation_type="base_model", step=0)
    if is_lr_scheduler_disabled(args.optimizer):
        optimizer.train()

# Only show the progress bar once on each machine.
show_progress_bar = True
@@ -2051,8 +2072,12 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
try:
- lr_scheduler.step(**scheduler_kwargs)
- lr = lr_scheduler.get_last_lr()[0]
+ if is_schedulefree:
+     # hackjob method of retrieving LR from accelerated optims
+     lr = StateTracker.get_last_lr()
+ else:
+     lr_scheduler.step(**scheduler_kwargs)
+     lr = lr_scheduler.get_last_lr()[0]
except Exception as e:
logger.error(
f"Failed to get the last learning rate from the scheduler. Error: {e}"
@@ -2164,7 +2189,12 @@ def main():
args.output_dir, f"checkpoint-{global_step}"
)
print("\n")
# schedulefree optim needs the optimizer to be in eval mode to save the state (and then back to train after)
if is_schedulefree:
    optimizer.eval()
accelerator.save_state(save_path)
if is_schedulefree:
    optimizer.train()
for _, backend in StateTracker.get_data_backends().items():
if "sampler" in backend:
logger.debug(f"Backend: {backend}")
@@ -2185,7 +2215,11 @@ def main():
"lr": lr,
}
progress_bar.set_postfix(**logs)
if is_schedulefree:
    optimizer.eval()
validation.run_validations(validation_type="intermediary", step=step)
if is_schedulefree:
    optimizer.train()
if (
args.push_to_hub
and args.push_checkpoints_to_hub
@@ -2220,6 +2254,8 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    if is_schedulefree:
        optimizer.eval()
validation_images = validation.run_validations(
validation_type="final",
step=global_step,