Commit 571e49c: te lr changes

linoytsaban committed Oct 1, 2024 (1 parent: 4faa6cf)
Showing 1 changed file with 6 additions and 6 deletions.
@@ -1717,7 +1717,7 @@ def load_model_hook(models, input_dir):
                 "weight_decay": args.adam_weight_decay_text_encoder
                 if args.adam_weight_decay_text_encoder
                 else args.adam_weight_decay,
-                "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+                "lr": args.text_encoder_lr,
             }
         if not args.enable_t5_ti:
             # pure textual inversion - only clip
@@ -1739,7 +1739,7 @@ def load_model_hook(models, input_dir):
                 "weight_decay": args.adam_weight_decay_text_encoder
                 if args.adam_weight_decay_text_encoder
                 else args.adam_weight_decay,
-                "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+                "lr": args.text_encoder_lr,
             }
         # pure textual inversion - only clip & t5
         if pure_textual_inversion:
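
For readers skimming the diff: the two hunks above edit the dict that defines a text-encoder parameter group passed to the optimizer. Below is a minimal sketch of that pattern; the argparse values and the text encoder are stand-ins, and the names are illustrative rather than the script's exact code.

```python
from types import SimpleNamespace

import torch

# Stand-ins for the script's argparse namespace and text encoder (illustrative only).
args = SimpleNamespace(
    adam_weight_decay=1e-2,
    adam_weight_decay_text_encoder=None,  # no TE-specific decay provided
    text_encoder_lr=5e-5,
)
text_encoder = torch.nn.Linear(4, 4)

text_encoder_group = {
    "params": list(text_encoder.parameters()),
    # fall back to the generic Adam weight decay when no TE-specific value is set
    "weight_decay": args.adam_weight_decay_text_encoder
    if args.adam_weight_decay_text_encoder
    else args.adam_weight_decay,
    # after this commit, "lr" takes --text_encoder_lr directly; the fallback to
    # --learning_rate that used to live here is handled later (prodigy branch)
    "lr": args.text_encoder_lr,
}
```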
@@ -1783,7 +1783,6 @@ def load_model_hook(models, input_dir):
            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW
-
        optimizer = optimizer_class(
            params_to_optimize,
            betas=(args.adam_beta1, args.adam_beta2),
@@ -1803,16 +1802,17 @@ def load_model_hook(models, input_dir):
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )
-        if args.train_text_encoder and args.text_encoder_lr:
+        if not freeze_text_encoder and args.text_encoder_lr:
            logger.warning(
                f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                f"When using prodigy only learning_rate is used as the initial learning rate."
            )
-            # changes the learning rate of text_encoder_parameters_one to be
+            # changes the learning rate of text_encoder_parameters to be
            # --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[te_idx]["lr"] = args.learning_rate
+            params_to_optimize[-1]["lr"] = args.learning_rate

        optimizer = optimizer_class(
            params_to_optimize,
            lr=args.learning_rate,
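
To make the last hunk concrete: here is a minimal, self-contained sketch of what the prodigy branch does, assuming the group layout the indices imply (transformer group first, CLIP text-encoder group at te_idx, T5 group last). All names and values below are hypothetical, and AdamW stands in for the prodigy optimizer to keep the sketch dependency-free.

```python
import torch

learning_rate = 1.0     # prodigy generally wants an initial lr around 1.0
text_encoder_lr = 5e-5  # per-group lr that prodigy would otherwise ignore

# Hypothetical parameter groups, in the order the indices above imply.
params_to_optimize = [
    {"params": [torch.nn.Parameter(torch.randn(4))], "lr": learning_rate},    # transformer
    {"params": [torch.nn.Parameter(torch.randn(4))], "lr": text_encoder_lr},  # CLIP TE -> te_idx
    {"params": [torch.nn.Parameter(torch.randn(4))], "lr": text_encoder_lr},  # T5 TE   -> index -1
]

# The commit replaces the hard-coded params_to_optimize[1] with te_idx and also
# resets the last group, so every text-encoder group starts from learning_rate.
te_idx = 1
params_to_optimize[te_idx]["lr"] = learning_rate
params_to_optimize[-1]["lr"] = learning_rate

# Stand-in for prodigyopt.Prodigy; only the initial lr matters for this sketch.
optimizer = torch.optim.AdamW(params_to_optimize, lr=learning_rate)
```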
