Support new optimizer Schedule free #1250

Merged (9 commits, May 4, 2024)
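This PR adds support for the Schedule-Free optimizers from the schedulefree package (facebookresearch/schedule_free). Two properties of these optimizers drive every change in the diff below: the learning-rate schedule lives inside the optimizer, so the lr_scheduler is neither handed to accelerator.prepare() nor stepped, and the optimizer has to be put into train mode around each optimization step and back into eval mode before the model is evaluated or saved. A minimal sketch of that usage pattern, using a toy model and random data rather than this repo's code:

# Hedged sketch of the schedule-free pattern; requires `pip install schedulefree`.
import torch
import schedulefree

model = torch.nn.Linear(10, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

for _ in range(100):
    optimizer.train()                      # optimizer must be in train mode around the step
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()                       # no lr_scheduler.step(): the schedule is internal
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()                       # back to eval mode before any evaluation or saving

with torch.no_grad():
    val_loss = model(torch.randn(8, 10)).pow(2).mean()  # evaluated with the eval-mode weights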
fine_tune.py: 27 additions, 8 deletions
@@ -255,18 +255,31 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]
     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if args.train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     # 実験的機能：勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
@@ -324,6 +337,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             m.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -390,9 +405,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
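The same prepare() branching as in fine_tune.py above recurs in every script in this PR: when --optimizer_type ends with "schedulefree", the lr_scheduler is simply not passed to accelerator.prepare(), because nothing will ever call its step(). A rough standalone illustration of the two branches, with a toy model and dataloader standing in for the U-Net, text encoders, and real dataset:

# Hedged sketch; assumes `accelerate` and `schedulefree` are installed.
import torch
import schedulefree
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 1)
dataloader = torch.utils.data.DataLoader(torch.randn(32, 10), batch_size=8)

use_schedule_free = True  # stands in for args.optimizer_type.lower().endswith("schedulefree")
if use_schedule_free:
    optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
    # schedule-free branch: no scheduler object is prepared (none is needed in this sketch)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    model, optimizer, dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, dataloader, lr_scheduler
    )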
library/train_util.py: 16 additions, 1 deletion
@@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
-        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
+        "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -4087,6 +4087,21 @@ def get_optimizer(args, trainable_params):
         logger.info(f"use AdamW optimizer | {optimizer_kwargs}")
         optimizer_class = torch.optim.AdamW
         optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

+    elif optimizer_type.endswith("schedulefree".lower()):
+        try:
+            import schedulefree as sf
+        except ImportError:
+            raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
+        if optimizer_type == "AdamWScheduleFree".lower():
+            optimizer_class = sf.AdamWScheduleFree
+            logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
+        elif optimizer_type == "SGDScheduleFree".lower():
+            optimizer_class = sf.SGDScheduleFree
+            logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
+        else:
+            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
+        optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
+
     if optimizer is None:
         # 任意のoptimizerを使う
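With the get_optimizer() change above, the new optimizers should be selectable through the scripts' existing --optimizer_type option as AdamWScheduleFree or SGDScheduleFree (the comparison is done on the lowercased string, so case does not matter). Because the external scheduler is no longer stepped, any warmup has to come from the optimizer itself. A hedged construction sketch; weight_decay and warmup_steps are keyword arguments of schedulefree.AdamWScheduleFree at the time of writing, but verify them against your installed version before forwarding them (for example via --optimizer_args):

# Hedged sketch of building the optimizer the way get_optimizer() would.
import torch
import schedulefree as sf

model = torch.nn.Linear(16, 16)
optimizer_kwargs = {"weight_decay": 0.01, "warmup_steps": 500}  # e.g. parsed from --optimizer_args
optimizer = sf.AdamWScheduleFree(model.parameters(), lr=1e-3, **optimizer_kwargs)
print(type(optimizer).__name__)  # AdamWScheduleFree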
sdxl_train.py: 19 additions, 5 deletions
@@ -415,9 +415,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2=text_encoder2 if train_text_encoder2 else None,
         )
         # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
@@ -428,7 +433,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder1 = accelerator.prepare(text_encoder1)
         if train_text_encoder2:
             text_encoder2 = accelerator.prepare(text_encoder2)
-        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
+        else:
+            optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -503,6 +511,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             m.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -626,9 +636,13 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
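A detail of the train()/eval() toggling that is easy to miss: switching modes swaps the weights that sit in the model (roughly, the raw training iterate in train mode versus the averaged iterate in eval mode), so anything that reads the weights, such as sample generation, validation, or writing a checkpoint, should happen while the optimizer is in eval mode, and the optimizer must be switched back before the next optimizer.step(). A hedged save/resume sketch, independent of this repo:

# Hedged sketch of checkpointing with a schedule-free optimizer.
import torch
import schedulefree

model = torch.nn.Linear(10, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

# ... training steps run with optimizer.train() active ...

optimizer.eval()  # averaged weights are now in the model; safe to evaluate or save
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "ckpt.pt")

# resuming later:
state = torch.load("ckpt.pt")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
optimizer.train()  # switch back to train mode before calling optimizer.step() again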
sdxl_train_control_net_lllite.py: 16 additions, 2 deletions
@@ -286,11 +286,19 @@ def train(args):
         unet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)

     if args.gradient_checkpointing:
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.train()
         unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+
     else:
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.eval()
         unet.eval()

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
@@ -390,6 +398,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():
@@ -481,9 +491,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
sdxl_train_control_net_lllite_old.py: 19 additions, 4 deletions
@@ -254,15 +254,24 @@ def train(args):
         network.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, network, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        unet, network, optimizer, train_dataloader = accelerator.prepare(
+            unet, network, optimizer, train_dataloader
+        )
+    else:
+        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, network, optimizer, train_dataloader, lr_scheduler
+        )
     network: control_net_lllite.ControlNetLLLite

     if args.gradient_checkpointing:
         unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.train()
     else:
         unet.eval()
+        if (args.optimizer_type.lower().endswith("schedulefree")):
+            optimizer.eval()

     network.prepare_grad_etc()
@@ -357,6 +366,8 @@ def remove_model(old_ckpt_name):
         network.on_epoch_start() # train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
@@ -449,9 +460,13 @@ def remove_model(old_ckpt_name):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
train_controlnet.py: 13 additions, 3 deletions
@@ -276,9 +276,14 @@ def train(args):
         controlnet.to(weight_dtype)

     # acceleratorがなんかよろしくやってくれるらしい
-    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        controlnet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.optimizer_type.lower().endswith("schedulefree"):
+        controlnet, optimizer, train_dataloader = accelerator.prepare(
+            controlnet, optimizer, train_dataloader
+        )
+    else:
+        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            controlnet, optimizer, train_dataloader, lr_scheduler
+        )

     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -393,6 +398,8 @@ def remove_model(old_ckpt_name):
         current_epoch.value = epoch + 1

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
@@ -472,6 +479,9 @@ def remove_model(old_ckpt_name):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)
train_db.py: 27 additions, 8 deletions
@@ -229,19 +229,32 @@ def train(args):
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            ds_model, optimizer, train_dataloader, lr_scheduler
-        )
+        if args.optimizer_type.lower().endswith("schedulefree"):
+            ds_model, optimizer, train_dataloader = accelerator.prepare(
+                ds_model, optimizer, train_dataloader
+            )
+        else:
+            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                ds_model, optimizer, train_dataloader, lr_scheduler
+            )
         training_models = [ds_model]

     else:
         if train_text_encoder:
-            unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-            )
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader
+                )
+            else:
+                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+                )
             training_models = [unet, text_encoder]
         else:
-            unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            if args.optimizer_type.lower().endswith("schedulefree"):
+                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+            else:
+                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
             training_models = [unet]

     if not train_text_encoder:
@@ -307,6 +320,8 @@ def train(args):
             text_encoder.train()

         for step, batch in enumerate(train_dataloader):
+            if (args.optimizer_type.lower().endswith("schedulefree")):
+                optimizer.train()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
@@ -384,9 +399,13 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                 optimizer.step()
-                lr_scheduler.step()
+                if not args.optimizer_type.lower().endswith("schedulefree"):
+                    lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)

+                if (args.optimizer_type.lower().endswith("schedulefree")):
+                    optimizer.eval()
+
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
                 progress_bar.update(1)