Merge pull request kohya-ss#1277 from Cauldrath/negative_learning
Allow negative learning rate
kohya-ss authored May 19, 2024
2 parents e4d9e3c + fc37437 commit 38e4c60
Showing 1 changed file with 3 additions and 3 deletions.
sdxl_train.py (6 changes: 3 additions & 3 deletions)
@@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # prepare for training: put the model into the appropriate state
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    train_unet = args.learning_rate > 0
+    train_unet = args.learning_rate != 0
     train_text_encoder1 = False
     train_text_encoder2 = False
 
@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2.gradient_checkpointing_enable()
         lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
         lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
-        train_text_encoder1 = lr_te1 > 0
-        train_text_encoder2 = lr_te2 > 0
+        train_text_encoder1 = lr_te1 != 0
+        train_text_encoder2 = lr_te2 != 0
 
         # caching one text encoder output is not supported
         if not train_text_encoder1:
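
Why `!= 0` rather than `> 0`: with the plain gradient update w ← w − lr·∇L, a negative learning rate steps in the direction that increases the loss, so the commit lets a module be deliberately trained away from the data ("negative learning") instead of silently freezing it, while a rate of exactly 0 still means "do not train". A minimal sketch of the new gating behavior is below; `should_train` is a hypothetical helper for illustration, not a function in sdxl_train.py, and only the `!= 0` comparison mirrors the commit:

# Illustrative sketch; `should_train` is hypothetical, not part of sdxl_train.py.
def should_train(learning_rate: float) -> bool:
    # Before this commit: learning_rate > 0, so a negative rate disabled training.
    # After: any nonzero rate enables training; 0 still keeps the module frozen.
    return learning_rate != 0

assert should_train(1e-4)      # positive rate: normal fine-tuning
assert not should_train(0.0)   # zero: module stays frozen
assert should_train(-1e-4)     # negative rate now trains (gradient ascent on the loss)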
