From fc374375de4fc9efd10eb598fdc166a4b6d0ad17 Mon Sep 17 00:00:00 2001
From: Cauldrath
Date: Thu, 18 Apr 2024 23:29:01 -0400
Subject: [PATCH] Allow negative learning rate

This can be used to train away from a group of images you don't want.
Because this moves the model away from a point instead of towards it,
the change in the model is unbounded, so don't set the rate too far
below zero; -4e-7 seemed to work well.
---
 sdxl_train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sdxl_train.py b/sdxl_train.py
index 46d7860be..1e6cec1a4 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
     # 学習を準備する:モデルを適切な状態にする
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
-    train_unet = args.learning_rate > 0
+    train_unet = args.learning_rate != 0
     train_text_encoder1 = False
     train_text_encoder2 = False
 
@@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
             text_encoder2.gradient_checkpointing_enable()
         lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate  # 0 means not train
         lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate  # 0 means not train
-        train_text_encoder1 = lr_te1 > 0
-        train_text_encoder2 = lr_te2 > 0
+        train_text_encoder1 = lr_te1 != 0
+        train_text_encoder2 = lr_te2 != 0
 
     # caching one text encoder output is not supported
     if not train_text_encoder1:
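
For context, an illustration that is not part of the patch: with the plain
SGD update theta -= lr * grad, a negative lr flips each step into gradient
ascent, so nothing pulls the parameters back toward a minimum and the drift
is unbounded. A minimal runnable sketch of that behavior on a toy quadratic
loss (theta, target, and lr are hypothetical names chosen for the example):

    import torch

    # Toy setup: one parameter pushed *away* from a target value, mirroring
    # how a negative learning rate trains the model away from unwanted images.
    theta = torch.zeros(1, requires_grad=True)
    target = torch.tensor([1.0])
    lr = -4e-7  # the negative value the commit message reports working well

    for step in range(3):
        loss = ((theta - target) ** 2).mean()  # distance to the point we move away from
        loss.backward()
        with torch.no_grad():
            # With lr < 0, "theta -= lr * grad" steps opposite the descent
            # direction, so the loss grows on every iteration.
            theta -= lr * theta.grad
        theta.grad.zero_()
        print(f"step {step}: loss={loss.item():.12f}")

The same sign flip carries over to whatever optimizer sdxl_train.py is
configured with, which is why the commit message warns against making the
rate too negative: there is no minimum to converge to in that direction.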