Skip to content

Commit

Permalink
Support new optimizer Schedule free (kohya-ss#1250)
Browse files Browse the repository at this point in the history
* init

* use no schedule

* fix typo

* update for eval()

* fix typo
  • Loading branch information
rockerBOO committed Jun 2, 2024
1 parent e5bab69 commit a652c91
Show file tree
Hide file tree
Showing 10 changed files with 209 additions and 48 deletions.
35 changes: 27 additions & 8 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,18 +255,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]
else:
# acceleratorがなんかよろしくやってくれるらしい
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
Expand Down Expand Up @@ -328,6 +341,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
m.train()

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
Expand Down Expand Up @@ -400,9 +415,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
15 changes: 15 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4250,6 +4250,21 @@ def get_optimizer(args, trainable_params):
logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# 任意のoptimizerを使う
Expand Down
24 changes: 19 additions & 5 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,9 +466,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
Expand All @@ -479,7 +484,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree"):
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

if args.fused_backward_pass:
# use fused optimizer for backward pass: other optimizers will be supported in the future
Expand Down Expand Up @@ -605,6 +613,8 @@ def optimizer_hook(parameter: torch.Tensor):
m.train()

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step

if args.fused_optimizer_groups:
Expand Down Expand Up @@ -740,7 +750,8 @@ def optimizer_hook(parameter: torch.Tensor):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
else:
# optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
Expand All @@ -749,6 +760,9 @@ def optimizer_hook(parameter: torch.Tensor):
for i in range(1, len(optimizers)):
lr_schedulers[i].step()

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
18 changes: 16 additions & 2 deletions sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,19 @@ def train(args):
unet.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

if args.gradient_checkpointing:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる

else:
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()
unet.eval()

# TextEncoderの出力をキャッシュするときにはCPUへ移動する
Expand Down Expand Up @@ -391,6 +399,8 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
Expand Down Expand Up @@ -486,9 +496,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
23 changes: 19 additions & 4 deletions sdxl_train_control_net_lllite_old.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,24 @@ def train(args):
network.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, network, optimizer, train_dataloader = accelerator.prepare(
unet, network, optimizer, train_dataloader
)
else:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
network: control_net_lllite.ControlNetLLLite

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
else:
unet.eval()
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

network.prepare_grad_etc()

Expand Down Expand Up @@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
network.on_epoch_start() # train()

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(network):
with torch.no_grad():
Expand Down Expand Up @@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
16 changes: 13 additions & 3 deletions train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,14 @@ def __contains__(self, name):
controlnet.to(weight_dtype)

# acceleratorがなんかよろしくやってくれるらしい
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
controlnet, optimizer, train_dataloader = accelerator.prepare(
controlnet, optimizer, train_dataloader
)
else:
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet, optimizer, train_dataloader, lr_scheduler
)

unet.requires_grad_(False)
text_encoder.requires_grad_(False)
Expand Down Expand Up @@ -417,6 +422,8 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step
with accelerator.accumulate(controlnet):
with torch.no_grad():
Expand Down Expand Up @@ -500,6 +507,9 @@ def remove_model(old_ckpt_name):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
35 changes: 27 additions & 8 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,32 @@ def train(args):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
else:
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
training_models = [ds_model]

else:
if train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
else:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
training_models = [unet, text_encoder]
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
training_models = [unet]

if not train_text_encoder:
Expand Down Expand Up @@ -307,6 +320,8 @@ def train(args):
text_encoder.train()

for step, batch in enumerate(train_dataloader):
if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.train()
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:
Expand Down Expand Up @@ -384,9 +399,13 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

if (args.optimizer_type.lower().endswith("schedulefree")):
optimizer.eval()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
Expand Down
Loading

0 comments on commit a652c91

Please sign in to comment.