From f99fe281cbb6519b7b5f1199c570d496ad4df474 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 1 Apr 2024 15:38:26 -0400
Subject: [PATCH] Add LoRA+ support

---
 library/train_util.py |  2 ++
 networks/dylora.py    | 45 ++++++++++++++++++++++++++----------
 networks/lora.py      | 54 ++++++++++++++++++++++++++++---------------
 train_network.py      |  2 +-
 4 files changed, 71 insertions(+), 32 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index d2b69edb5..4e5ab7370 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
         default=1,
         help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power",
     )
+    parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio")
+    parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio")
 
 
 def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool):
diff --git a/networks/dylora.py b/networks/dylora.py
index 637f33450..a73ade8bd 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -406,27 +406,48 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
         logger.info(f"weights are merged")
     """
 
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    # It may be better to allow setting a separate learning rate for each of the two Text Encoders
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)
 
         return all_params
 
diff --git a/networks/lora.py b/networks/lora.py
index 948b30b0e..8d7619777 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float:
         return lr_weight
 
     # It may be better to allow setting a separate learning rate for each of the two Text Encoders
-    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None):
         self.requires_grad_(True)
         all_params = []
 
-        def enumerate_params(loras):
-            params = []
+        def assemble_params(loras, lr, lora_plus_ratio):
+            param_groups = {"lora": {}, "plus": {}}
             for lora in loras:
-                params.extend(lora.parameters())
+                for name, param in lora.named_parameters():
+                    if lora_plus_ratio is not None and "lora_up" in name:
+                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
+                    else:
+                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
+
+            # assigned_param_groups = ""
+            # for group in param_groups:
+            #     assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
+            # logger.info(assigned_param_groups)
+
+            params = []
+            for key in param_groups.keys():
+                param_data = {"params": param_groups[key].values()}
+                if lr is not None:
+                    if key == "plus":
+                        param_data["lr"] = lr * lora_plus_ratio
+                    else:
+                        param_data["lr"] = lr
+
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+
+                params.append(param_data)
+
             return params
 
         if self.text_encoder_loras:
-            param_data = {"params": enumerate_params(self.text_encoder_loras)}
-            if text_encoder_lr is not None:
-                param_data["lr"] = text_encoder_lr
-            all_params.append(param_data)
+            params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio)
+            all_params.extend(params)
 
         if self.unet_loras:
             if self.block_lr:
@@ -1063,21 +1085,15 @@ def enumerate_params(loras):
 
             # set the parameters for each block
             for idx, block_loras in block_idx_to_lora.items():
-                param_data = {"params": enumerate_params(block_loras)}
                 if unet_lr is not None:
-                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                    params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
                 elif default_lr is not None:
-                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
-                if ("lr" in param_data) and (param_data["lr"] == 0):
-                    continue
-                all_params.append(param_data)
+                    params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio)
+                all_params.extend(params)
 
         else:
-            param_data = {"params": enumerate_params(self.unet_loras)}
-            if unet_lr is not None:
-                param_data["lr"] = unet_lr
-            all_params.append(param_data)
+            params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio)
+            all_params.extend(params)
 
         return all_params
 
diff --git a/train_network.py b/train_network.py
index e0fa69458..ba0c124d1 100644
--- a/train_network.py
+++ b/train_network.py
@@ -339,7 +339,7 @@ def train(self, args):
 
         # ensure backward compatibility
         try:
-            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
+            trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_unet_lr_ratio, args.loraplus_text_encoder_lr_ratio)
         except TypeError:
             accelerator.print(
                 "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
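
For context on what the new flags do: LoRA+ trains the "lora_up" ("B") weights with a higher learning rate than the "lora_down" ("A") weights, and --loraplus_unet_lr_ratio / --loraplus_text_encoder_lr_ratio supply that multiplier, which assemble_params applies to the "plus" parameter group. Below is a minimal, self-contained sketch of that grouping; the DummyLoRA module, the learning rate, and the ratio value are hypothetical and exist only to show the resulting optimizer parameter groups, not the exact code in this patch.

import torch


class DummyLoRA(torch.nn.Module):
    # Hypothetical stand-in for a LoRAModule: one down ("A") and one up ("B") projection.
    def __init__(self, in_dim=8, rank=4, out_dim=8):
        super().__init__()
        self.lora_name = "lora_unet_dummy"
        self.lora_down = torch.nn.Linear(in_dim, rank, bias=False)
        self.lora_up = torch.nn.Linear(rank, out_dim, bias=False)


def assemble_params(loras, lr, lora_plus_ratio):
    # Same grouping idea as the patch: "lora_up" parameters go into a "plus" group
    # whose learning rate is lr * lora_plus_ratio; all other parameters keep lr.
    groups = {"lora": [], "plus": []}
    for lora in loras:
        for name, param in lora.named_parameters():
            key = "plus" if lora_plus_ratio is not None and "lora_up" in name else "lora"
            groups[key].append(param)

    params = []
    for key, group_params in groups.items():
        if not group_params:
            continue
        param_data = {"params": group_params}
        if lr is not None:
            param_data["lr"] = lr * lora_plus_ratio if key == "plus" else lr
        if param_data.get("lr") == 0:
            continue  # skip groups whose effective lr is zero, as in the patch
        params.append(param_data)
    return params


param_groups = assemble_params([DummyLoRA()], lr=1e-4, lora_plus_ratio=16)
optimizer = torch.optim.AdamW(param_groups)
for group in optimizer.param_groups:
    print(group["lr"], sum(p.numel() for p in group["params"]))
# expected: lr 1e-4 for the lora_down group, lr 1.6e-3 for the lora_up group

In training, the same effect would come from passing, for example, --unet_lr with --loraplus_unet_lr_ratio=16 so that the UNet's lora_up weights train at sixteen times the UNet learning rate.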