Support d*lr for ProdigyPlus optimizer
rockerBOO committed Nov 20, 2024
1 parent 0dbb0d9 commit ef9a9c1
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions train_network.py
@@ -53,7 +53,7 @@ def __init__(self):
 
     # TODO: unify this with the other scripts
     def generate_step_logs(
-        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None
+        self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, optimizer, keys_scaled=None, mean_norm=None, maximum_norm=None
     ):
         logs = {"loss/current": current_loss, "loss/average": avr_loss}
 
@@ -79,6 +79,12 @@ def generate_step_logs(
                 logs["lr/d*lr"] = (
                     lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"]
                 )
+            if (
+                args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower())
+            ):  # tracking d*lr value of unet.
+                logs["lr/d*lr"] = (
+                    optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
+                )
         else:
             idx = 0
             if not args.network_train_unet_only:
@@ -91,6 +97,12 @@ def generate_step_logs(
                     logs[f"lr/d*lr/group{i}"] = (
                         lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
                     )
+                if (
+                    args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower())
+                ):
+                    logs[f"lr/d*lr/group{i}"] = (
+                        optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                    )
 
         return logs
 
@@ -965,7 +977,7 @@ def remove_model(old_ckpt_name):
                     progress_bar.set_postfix(**{**max_mean_logs, **logs})
 
                 if args.logging_dir is not None:
-                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
+                    logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, optimizer, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)
 
                 if global_step >= args.max_train_steps:

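For reference, Prodigy-family optimizers (including ProdigyPlusScheduleFree) adapt their effective step size by scaling the base learning rate with an estimated distance term d, so the quantity worth tracking is the product d * lr read from the optimizer's param groups, exactly as the added lines above do. Below is a minimal sketch, not part of the commit, of how such a metric could be gathered from any optimizer that exposes "d" in its param groups; the helper name collect_d_lr_logs is hypothetical.

def collect_d_lr_logs(optimizer, prefix="lr/d*lr"):
    # Sketch only: assumes a Prodigy-style optimizer that stores its adaptive
    # step-size estimate under the "d" key of each param group.
    logs = {}
    for i, group in enumerate(optimizer.param_groups):
        if "d" not in group or "lr" not in group:
            continue  # skip groups without Prodigy-style state
        key = prefix if len(optimizer.param_groups) == 1 else f"{prefix}/group{i}"
        logs[key] = group["d"] * group["lr"]  # effective step size for this group
    return logs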