add schedulefree optim w/ kahan summation #857

Merged: 6 commits, Aug 23, 2024

Changes from all commits
35 changes: 31 additions & 4 deletions helpers/training/optimizer_param.py
@@ -9,7 +9,8 @@
logger.setLevel(target_level)

is_optimi_available = False
from helpers.training.adam_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_bfloat16 import AdamWBF16
from helpers.training.optimizers.adamw_schedulefree import AdamWScheduleFreeKahan

try:
    from optimum.quanto import QTensor
@@ -35,6 +36,17 @@
        },
        "class": AdamWBF16,
    },
    "adamw_schedulefree": {
        "precision": "any",
        "override_lr_scheduler": True,
        "can_warmup": True,
        "default_settings": {
            "betas": (0.9, 0.999),
            "weight_decay": 1e-2,
            "eps": 1e-8,
        },
        "class": AdamWScheduleFreeKahan,
    },
    "optimi-stableadamw": {
        "precision": "any",
        "default_settings": {
@@ -154,8 +166,8 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw or optimi-lion instead - for decoupled learning rate, use --optimizer_config=decoupled_lr=True.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use optimi-stableadamw instead.",
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
}
@@ -206,6 +218,16 @@ def optimizer_parameters(optimizer, args):
        raise ValueError(f"Optimizer {optimizer} not found.")


def is_lr_scheduler_disabled(optimizer: str):
    """Check if the optimizer has a built-in LR scheduler"""
    is_disabled = False
    if optimizer in optimizer_choices:
        is_disabled = optimizer_choices.get(optimizer).get(
            "override_lr_scheduler", False
        )
    return is_disabled


def show_optimizer_defaults(optimizer: str = None):
    """we'll print the defaults on a single line, eg. foo=bar, buz=baz"""
    if optimizer is None:
@@ -260,7 +282,12 @@ def determine_optimizer_class_with_config(
    else:
        optimizer_class, optimizer_details = optimizer_parameters(args.optimizer, args)
        default_settings = optimizer_details.get("default_settings")
        logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
        if optimizer_details.get("can_warmup", False):
            logger.info(
                f"Optimizer contains LR scheduler, warmup steps will be set to {args.lr_warmup_steps}."
            )
            default_settings["warmup_steps"] = args.lr_warmup_steps
        logger.info(f"cls: {optimizer_class}, settings: {default_settings}")
    return default_settings, optimizer_class


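The two new registry flags drive everything downstream: `can_warmup` routes the CLI warmup step count into the optimizer's own constructor, and `override_lr_scheduler` tells the trainer to skip building an external scheduler. A minimal sketch of that flow, assuming only what this diff defines; the `args` namespace below is a stand-in for the parsed CLI arguments, not the project's real parser:

from types import SimpleNamespace

from helpers.training.optimizer_param import (
    is_lr_scheduler_disabled,
    optimizer_choices,
)

args = SimpleNamespace(optimizer="adamw_schedulefree", lr_warmup_steps=100)

details = optimizer_choices[args.optimizer]
settings = dict(details["default_settings"])
if details.get("can_warmup", False):
    # warmup is handled inside the optimizer, so pass the step count through
    settings["warmup_steps"] = args.lr_warmup_steps

assert is_lr_scheduler_disabled(args.optimizer)  # -> train.py builds no scheduler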
149 changes: 149 additions & 0 deletions helpers/training/optimizers/adamw_schedulefree/__init__.py
@@ -0,0 +1,149 @@
import torch
from torch.optim.optimizer import Optimizer
import math
from typing import Iterable
from helpers.training.state_tracker import StateTracker


class AdamWScheduleFreeKahan(Optimizer):
    """AdamW optimizer with schedule-free adjustments and Kahan summation.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        betas: Coefficients for gradient and squared gradient moving averages (default: (0.9, 0.999)).
        eps: Added to denominator to improve numerical stability (default: 1e-8).
        weight_decay: Weight decay coefficient (default: 1e-2).
        warmup_steps: Number of steps to warm up the learning rate (default: 0).
        kahan_sum: Enables Kahan summation for more accurate parameter updates when training in low precision.
    """

    def __init__(
        self,
        params: Iterable,
        lr: float = 1e-3,
        betas: tuple = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        warmup_steps: int = 0,
        kahan_sum: bool = True,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            warmup_steps=warmup_steps,
            kahan_sum=kahan_sum,
        )
        super(AdamWScheduleFreeKahan, self).__init__(params, defaults)
        self.k = 0
        self.lr_max = -1.0
        self.last_lr = -1.0
        self.weight_sum = 0.0

    def _initialize_state(self, state, p):
        if "step" not in state:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
            state["exp_avg_sq"] = torch.zeros_like(
                p, memory_format=torch.preserve_format
            )
            if self.defaults["kahan_sum"]:
                state["kahan_comp"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

    def eval(self):
        for group in self.param_groups:
            train_mode = group.get("train_mode", True)
            beta1, _ = group["betas"]
            if train_mode:
                for p in group["params"]:
                    state = self.state[p]
                    if "z" in state:
                        # Set p.data to x
                        p.data.lerp_(
                            end=state["z"].to(p.data.device), weight=1 - 1 / beta1
                        )
                group["train_mode"] = False

    def train(self):
        for group in self.param_groups:
            train_mode = group.get("train_mode", False)
            beta1, _ = group["betas"]
            if not train_mode:
                for p in group["params"]:
                    state = self.state[p]
                    if "z" in state:
                        # Set p.data to y
                        p.data.lerp_(end=state["z"].to(p.data.device), weight=1 - beta1)
                group["train_mode"] = True

    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            warmup_steps = group["warmup_steps"]
            kahan_sum = group["kahan_sum"]

            k = self.k

            # Adjust learning rate with warmup
            if k < warmup_steps:
                sched = (k + 1) / warmup_steps
            else:
                sched = 1.0

            bias_correction2 = 1 - beta2 ** (k + 1)
            adjusted_lr = lr * sched * (bias_correction2**0.5)
            self.lr_max = max(adjusted_lr, self.lr_max)

for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data

state = self.state[p]
self._initialize_state(state, p)

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

if kahan_sum:
kahan_comp = state["kahan_comp"]
grad.add_(kahan_comp)

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(eps)

step_size = adjusted_lr / (bias_correction2**0.5)
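                # Note: adjusted_lr = lr * sched * sqrt(bias_correction2), so the
                # sqrt(bias_correction2) factors cancel here and the applied step
                # size is effectively lr * sched; the bias correction only shapes
                # the adjusted_lr value reported via StateTracker below.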

                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-weight_decay)

                # Kahan summation to improve precision
                step = exp_avg / denom
                p.data.add_(-step_size * step)

                if kahan_sum:
                    buffer = p.data.add(-step_size * step)
                    kahan_comp.copy_(p.data.sub(buffer).add(buffer.sub_(p.data)))

        self.k += 1
        self.last_lr = adjusted_lr
        StateTracker.set_last_lr(adjusted_lr)

        return loss

    def get_last_lr(self):
        return self.last_lr
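For reviewers unfamiliar with Kahan summation, here is a standalone illustration (independent of the PR code) of the failure mode it compensates for: at bf16 precision, a per-step update far below one ULP of the weight is rounded away entirely, while a compensation buffer accumulates it faithfully.

import torch

p_naive = torch.ones(1, dtype=torch.bfloat16)
p_kahan = torch.ones(1, dtype=torch.bfloat16)
comp = torch.zeros(1, dtype=torch.bfloat16)  # running rounding-error buffer
update = torch.full((1,), 1e-4, dtype=torch.bfloat16)

for _ in range(1000):
    p_naive += update  # 1e-4 is below half a ULP of 1.0 in bf16, so it rounds away

    y = update + comp          # fold the stored error into the addend
    t = p_kahan + y            # the large-magnitude addition
    comp = y - (t - p_kahan)   # capture what that addition just lost
    p_kahan = t

print(p_naive.item())  # stays at 1.0
print(p_kahan.item())  # lands close to the true 1.1

The optimizer above folds its compensation buffer into the gradient (grad.add_(kahan_comp)) rather than into the weight addition directly, but the principle is the same.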
11 changes: 11 additions & 0 deletions helpers/training/state_tracker.py
@@ -49,6 +49,9 @@ class StateTracker:
    # Aspect to resolution map, we'll store once generated for consistency.
    aspect_resolution_map = {}

    # for schedulefree
    last_lr = 0.0

    # hugging face hub user details
    hf_user = None

@@ -530,3 +533,11 @@ def load_aspect_resolution_map(cls, dataloader_resolution: float):
        logger.debug(
            f"Aspect resolution map: {cls.aspect_resolution_map[dataloader_resolution]}"
        )

    @classmethod
    def get_last_lr(cls):
        return cls.last_lr

    @classmethod
    def set_last_lr(cls, last_lr: float):
        cls.last_lr = float(last_lr)
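These accessors exist so the optimizer can publish its effective learning rate without the trainer holding a scheduler; the round-trip (values illustrative) is:

from helpers.training.state_tracker import StateTracker

StateTracker.set_last_lr(2.5e-4)  # called at the end of AdamWScheduleFreeKahan.step()
lr = StateTracker.get_last_lr()   # read in train.py where lr_scheduler.get_last_lr() was used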
72 changes: 54 additions & 18 deletions train.py
@@ -33,6 +33,8 @@
from helpers.training.validation import Validation, prepare_validation_prompt_list
from helpers.training.state_tracker import StateTracker
from helpers.training.schedulers import load_scheduler_from_args
from helpers.training.custom_schedule import get_lr_scheduler
from helpers.training.optimizer_param import is_lr_scheduler_disabled
from helpers.training.adapter import determine_adapter_target_modules, load_lora_weights
from helpers.training.diffusion_model import load_diffusion_model
from helpers.training.text_encoding import (
@@ -839,9 +841,15 @@ def main():
f" {args.num_train_epochs} epochs and {num_update_steps_per_epoch} steps per epoch."
)
overrode_max_train_steps = True
logger.info(
f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
)
is_schedulefree = is_lr_scheduler_disabled(args.optimizer)
if is_schedulefree:
logger.info(
"Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation."
)
else:
logger.info(
f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
)
if args.layer_freeze_strategy == "bitfit":
from helpers.training.model_freeze import apply_bitfit_freezing

@@ -978,26 +986,31 @@ def main():
        optimizer,
    )

    from helpers.training.custom_schedule import get_lr_scheduler

    if not use_deepspeed_scheduler:
    if is_lr_scheduler_disabled(args.optimizer):
        # we don't use LR schedulers with schedule-free optimisers (lol)
        lr_scheduler = None
    if not use_deepspeed_scheduler and not is_schedulefree:
        logger.info(
            f"Loading {args.lr_scheduler} learning rate scheduler with {args.lr_warmup_steps} warmup steps"
        )
        lr_scheduler = get_lr_scheduler(
            args, optimizer, accelerator, logger, use_deepspeed_scheduler=False
        )
    else:
        logger.info(f"Using DeepSpeed learning rate scheduler")
        lr_scheduler = accelerate.utils.DummyScheduler(
            optimizer,
            total_num_steps=args.max_train_steps,
            warmup_num_steps=args.lr_warmup_steps,
        )
    if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
        lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch
    if hasattr(lr_scheduler, "last_step"):
        lr_scheduler.last_step = global_resume_step
        logger.info(f"Using dummy learning rate scheduler")
        if torch.backends.mps.is_available():
            lr_scheduler = None
        else:
            lr_scheduler = accelerate.utils.DummyScheduler(
                optimizer,
                total_num_steps=args.max_train_steps,
                warmup_num_steps=args.lr_warmup_steps,
            )
    if lr_scheduler is not None:
        if hasattr(lr_scheduler, "num_update_steps_per_epoch"):
            lr_scheduler.num_update_steps_per_epoch = num_update_steps_per_epoch
        if hasattr(lr_scheduler, "last_step"):
            lr_scheduler.last_step = global_resume_step

    accelerator.wait_for_everyone()

@@ -1285,6 +1298,10 @@ def main():
if "sampler" in backend:
backend["sampler"].log_state()

if is_lr_scheduler_disabled(args.optimizer) and hasattr(optimizer, "train"):
# we typically have to call train() on the optim for schedulefree.
optimizer.train()

total_steps_remaining_at_start = args.max_train_steps
# We store the number of dataset resets that have occurred inside the checkpoint.
first_epoch = StateTracker.get_epoch()
@@ -1399,7 +1416,11 @@ def main():
    if webhook_handler is not None:
        webhook_handler.send(message=initial_msg)
    if args.validation_on_startup and global_step <= 1:
        if is_lr_scheduler_disabled(args.optimizer):
            optimizer.eval()
        validation.run_validations(validation_type="base_model", step=0)
        if is_lr_scheduler_disabled(args.optimizer):
            optimizer.train()

    # Only show the progress bar once on each machine.
    show_progress_bar = True
@@ -2051,8 +2072,12 @@ def main():
            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                try:
                    lr_scheduler.step(**scheduler_kwargs)
                    lr = lr_scheduler.get_last_lr()[0]
                    if is_schedulefree:
                        # hackjob method of retrieving LR from accelerated optims
                        lr = StateTracker.get_last_lr()
                    else:
                        lr_scheduler.step(**scheduler_kwargs)
                        lr = lr_scheduler.get_last_lr()[0]
                except Exception as e:
                    logger.error(
                        f"Failed to get the last learning rate from the scheduler. Error: {e}"
@@ -2164,7 +2189,12 @@
                        args.output_dir, f"checkpoint-{global_step}"
                    )
                    print("\n")
                    # schedulefree optim needs the optimizer to be in eval mode to save the state (and then back to train after)
                    if is_schedulefree:
                        optimizer.eval()
                    accelerator.save_state(save_path)
                    if is_schedulefree:
                        optimizer.train()
                    for _, backend in StateTracker.get_data_backends().items():
                        if "sampler" in backend:
                            logger.debug(f"Backend: {backend}")
@@ -2185,7 +2215,11 @@
"lr": lr,
}
progress_bar.set_postfix(**logs)
if is_schedulefree:
optimizer.eval()
validation.run_validations(validation_type="intermediary", step=step)
if is_schedulefree:
optimizer.train()
if (
args.push_to_hub
and args.push_checkpoints_to_hub
@@ -2220,6 +2254,8 @@
    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        if is_schedulefree:
            optimizer.eval()
        validation_images = validation.run_validations(
            validation_type="final",
            step=global_step,
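The optimizer.eval()/optimizer.train() toggling recurs around every validation run and checkpoint save above. A hypothetical helper, not part of this PR, could factor that pattern out:

from contextlib import contextmanager

@contextmanager
def schedulefree_eval(optimizer, enabled: bool):
    """Swap a schedule-free optimizer to its evaluation weights for the
    duration of the block, then restore training mode."""
    if enabled and hasattr(optimizer, "eval"):
        optimizer.eval()
    try:
        yield
    finally:
        if enabled and hasattr(optimizer, "train"):
            optimizer.train()

# e.g. around checkpointing:
# with schedulefree_eval(optimizer, is_schedulefree):
#     accelerator.save_state(save_path)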