From b1dffe8d9ae1c02a06e8871a844c42d6729623ce Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Fri, 31 Mar 2023 00:11:11 +0900
Subject: [PATCH 1/9] Fix bug that prevented checkpoint loading (Exception: device cuda is invalid)

---
 library/model_util.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/library/model_util.py b/library/model_util.py
index e227ced83..f3f236af1 100644
--- a/library/model_util.py
+++ b/library/model_util.py
@@ -831,7 +831,7 @@ def is_safetensors(path):
   return os.path.splitext(path)[1].lower() == '.safetensors'
 
 
-def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
   # handle models that store the text encoder in a different layout (no 'text_model')
   TEXT_ENCODER_KEY_REPLACEMENTS = [
       ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
@@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device):
 
 # TODO the dtype option behaves suspiciously and needs checking; building the text encoder in the specified dtype is unverified
 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None):
-  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device)
+  _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)  # no need to specify device in loading state_dict
 
   # Convert the UNet2DConditionModel model.
   unet_config = create_unet_diffusers_config(v2)
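
The exception in the subject appears to come from handing a bare device string such as "cuda" straight to the state_dict loader. Loading the checkpoint to CPU first and moving tensors afterwards sidesteps the problem entirely. A minimal standalone sketch of that pattern (illustrative only; the function name and structure here are not the repository's exact loader):

import torch
from safetensors.torch import load_file

def load_state_dict(ckpt_path, device="cpu"):
    # Loading to CPU first avoids device-string quirks in the loaders and
    # keeps peak GPU memory low; tensors can be moved to the GPU afterwards.
    if ckpt_path.endswith(".safetensors"):
        return load_file(ckpt_path, device=device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    return checkpoint.get("state_dict", checkpoint)
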
From 4dacc52bde623e7a562d054585250ba8a7737c0a Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Fri, 31 Mar 2023 00:39:35 +0900
Subject: [PATCH 2/9] implement stratified_lr

---
 networks/lora.py | 145 ++++++++++++++++++++++++++++++++++++++++------
 train_network.py |   2 +-
 2 files changed, 127 insertions(+), 20 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index 2bf785118..4dbf79f95 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -8,9 +8,11 @@ from typing import List
 import numpy as np
 import torch
+import re
 
 from library import train_util
 
+RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
 
 class LoRAModule(torch.nn.Module):
   """
@@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
     else:
       conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
       assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
-  """
+  """
 
   network = LoRANetwork(
       text_encoder,
@@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
       conv_lora_dim=conv_dim,
       conv_alpha=conv_alpha,
   )
+
+  up_weight=None
+  if 'up_weight' in kwargs:
+    up_weight = kwargs.get('up_weight',None)
+    if "," in up_weight:
+      up_weight = [float(s) for s in up_weight.split(",") if s]
+  down_weight=None
+  if 'down_weight' in kwargs:
+    down_weight = kwargs.get('down_weight',None)
+    if "," in down_weight:
+      down_weight = [float(s) for s in down_weight.split(",") if s]
+
+  network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('lr_weight_threshold', 0.0)))
+
   return network
 
@@ -318,6 +334,10 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)
 
+    self.up_weight:list[float] = None
+    self.down_weight:list[float] = None
+    self.mid_weight:float = None
+
   def set_multiplier(self, multiplier):
     self.multiplier = multiplier
     for lora in self.text_encoder_loras + self.unet_loras:
@@ -366,9 +386,17 @@ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None)
     else:
       self.unet_loras = []
 
+    skipped = []
     for lora in self.text_encoder_loras + self.unet_loras:
+      if self.get_stratified_lr_weight(lora) == 0:
+        skipped.append(lora.lora_name)
+        continue
       lora.apply_to()
       self.add_module(lora.lora_name, lora)
+    if len(skipped)>0:
+      print(f"stratified_lr_weight is 0, so the following {len(skipped)} LoRA modules are skipped:")
+      for name in skipped:
+        print(f"\t{name}")
 
     if self.weights_sd:
       # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
@@ -404,34 +432,113 @@ def merge_to(self, text_encoder, unet, dtype, device):
       lora.merge_to(sd_for_lora, dtype, device)
     print(f"weights are merged")
 
-  def enable_gradient_checkpointing(self):
-    # not supported
-    pass
-
-  def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
-    def enumerate_params(loras):
-      params = []
-      for lora in loras:
-        params.extend(lora.parameters())
-      return params
-
+  # define per-layer multipliers applied to the learning rate (stratified learning rates)
+  def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0):
+    max_len = 3  # attentions -> attentions -> attentions: defined over 3 modules
+    if self.apply_to_conv2d_3x3:
+      max_len = 10  # (resnets -> {up,down}sampler -> attentions) x3 -> resnets: defined over 10 modules
+
+    def get_list(name) -> list[float]:
+      import math
+      if name=="cosine":
+        return [math.cos(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
+      elif name=="sine":
+        return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
+      elif name=="linear":
+        return [i/(max_len-1) for i in range(max_len)]
+      elif name=="reverse_linear":
+        return [i/(max_len-1) for i in reversed(range(max_len))]
+      elif name=="zeros":
+        return [0.0] * max_len
+      else:
+        print("Unknown weight argument %s was used.\n\tValid arguments: cosine, sine, linear, reverse_linear, zeros"%(name))
+        return None
+
+    if type(down_weight)==str:
+      down_weight=get_list(down_weight)
+    if type(up_weight)==str:
+      up_weight=get_list(up_weight)
+
+    if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len):
+      print("down_weight or up_weight is too long; values after the %dth are ignored."%max_len)
+    if (up_weight != None and len(up_weight)<max_len) or (down_weight != None and len(down_weight)<max_len):
+      if (down_weight != None and len(down_weight)<max_len):
+        down_weight = down_weight + [1.0] * (max_len - len(down_weight))
+      if (up_weight != None and len(up_weight)<max_len):
+        up_weight = up_weight + [1.0] * (max_len - len(up_weight))
+    if (up_weight != None) or (mid_weight != None) or (down_weight != None):
+      print("Applying stratified learning rates.")
+    if (down_weight != None):
+      self.down_weight = [w if w > zero_threshold else 0 for w in down_weight[:max_len]]
+      print("down_weight (shallow -> deep):",self.down_weight)
+    if (mid_weight != None):
+      self.mid_weight = mid_weight if mid_weight > zero_threshold else 0
+      print("mid_weight:",self.mid_weight)
+    if (up_weight != None):
+      self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]]
+      print("up_weight (deep -> shallow):",self.up_weight)
+    return
+
+  def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
+    m = RE_UPDOWN.search(lora.lora_name)
+    if m:
+      idx = 0
+      g = m.groups()
+      i = int(g[1])
+      if self.apply_to_conv2d_3x3:
+        if g[2]=="resnets":
+          idx=3*i
+        elif g[2]=="attentions":
+          if g[0]=="down":
+            idx=3*i + 2
+          else:
+            idx=3*i - 1
+        elif g[2]=="upsamplers" or g[2]=="downsamplers":
+          idx=3*i + 1
+      else:
+        idx=i
+        if g[0]=="up":
+          idx=i-1
+
+      if (g[0]=="up") and (self.up_weight != None):
+        return self.up_weight[idx]
+      elif (g[0]=="down") and (self.down_weight != None):
+        return self.down_weight[idx]
+      elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None):
+        return self.mid_weight
+    # print({'params': lora.parameters(), 'lr':alpha*lr})
+    return 1
+
+  def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
     self.requires_grad_(True)
     all_params = []
 
     if self.text_encoder_loras:
-      param_data = {"params": enumerate_params(self.text_encoder_loras)}
+      params = []
+      for lora in self.text_encoder_loras:
+        params.extend(lora.parameters())
+      param_data = {'params': params}
       if text_encoder_lr is not None:
-        param_data["lr"] = text_encoder_lr
+        param_data['lr'] = text_encoder_lr
       all_params.append(param_data)
 
     if self.unet_loras:
-      param_data = {"params": enumerate_params(self.unet_loras)}
-      if unet_lr is not None:
-        param_data["lr"] = unet_lr
-      all_params.append(param_data)
-
+      for lora in self.unet_loras:
+        param_data={}
+        if unet_lr is not None:
+          param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr}
+        elif default_lr is not None:
+          param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr}
+        if param_data["lr"]==0:
+          continue
+        all_params.append(param_data)
     return all_params
 
+  def enable_gradient_checkpointing(self):
+    # not supported
+    pass
+
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)
 
diff --git a/train_network.py b/train_network.py
index 200d8d84d..eb5301e20 100644
--- a/train_network.py
+++ b/train_network.py
@@ -191,7 +191,7 @@ def train(args):
 
   # prepare classes needed for training
   print("prepare optimizer, data loader etc.")
 
-  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
   optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # prepare dataloader
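
The mechanism behind prepare_optimizer_params is the standard torch optimizer contract: each parameter group may carry its own "lr", so scaling a group's learning rate by a per-block weight yields layer-wise learning rates, and a weight of 0 excludes the block from the optimizer altogether. A self-contained sketch with made-up block names and weights:

import torch

blocks = {"down_0": torch.nn.Linear(4, 4), "mid": torch.nn.Linear(4, 4), "up_0": torch.nn.Linear(4, 4)}
lr_weights = {"down_0": 0.5, "mid": 1.0, "up_0": 2.0}  # illustrative multipliers
base_lr = 1e-4

param_groups = []
for name, module in blocks.items():
    weight = lr_weights[name]
    if weight == 0:
        continue  # weight 0 means the block is not trained at all
    param_groups.append({"params": module.parameters(), "lr": base_lr * weight})

optimizer = torch.optim.AdamW(param_groups)
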
(g[0]=="down") and (self.down_weight != None): + return self.down_weight[idx] + elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None): + return self.mid_weight + # print({'params': lora.parameters(), 'lr':alpha*lr}) + return 1 + + def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr): self.requires_grad_(True) all_params = [] if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} + params = [] + for lora in self.text_encoder_loras: + params.extend(lora.parameters()) + param_data = {'params': params} if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr + param_data['lr'] = text_encoder_lr all_params.append(param_data) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) - + for lora in self.unet_loras: + param_data={} + if unet_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*unet_lr} + elif default_lr is not None: + param_data = {'params': lora.parameters(), 'lr':self.get_stratified_lr_weight(lora)*default_lr} + if param_data["lr"]==0: + continue + all_params.append(param_data) return all_params + def enable_gradient_checkpointing(self): + # not supported + pass + def prepare_grad_etc(self, text_encoder, unet): self.requires_grad_(True) diff --git a/train_network.py b/train_network.py index 200d8d84d..eb5301e20 100644 --- a/train_network.py +++ b/train_network.py @@ -191,7 +191,7 @@ def train(args): # 学習に必要なクラスを準備する print("prepare optimizer, data loader etc.") - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する From dade23a4149494e8eb9342463aba13a6f6e04b98 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 01:14:03 +0900 Subject: [PATCH 3/9] =?UTF-8?q?stratified=5Fzero=5Fthreshold=E3=81=AB?= =?UTF-8?q?=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 4dbf79f95..ad8331c81 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -202,7 +202,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un if "," in down_weight: down_weight = [float(s) for s in down_weight.split(",") if s] - network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('lr_weight_threshold', 0.0))) + network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('stratified_zero_threshold', 0.0))) return network From 1b75dbd4f2553bdc09fdfc1d10fa007926a907b5 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 01:40:29 +0900 Subject: [PATCH 4/9] =?UTF-8?q?=E5=BC=95=E6=95=B0=E5=90=8D=E3=81=AB=5Flr?= =?UTF-8?q?=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 86 ++++++++++++++++++++++++------------------------ 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 
From 1b75dbd4f2553bdc09fdfc1d10fa007926a907b5 Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Fri, 31 Mar 2023 01:40:29 +0900
Subject: [PATCH 4/9] Add _lr to the argument names

---
 networks/lora.py | 86 ++++++++++++++++++++++++------------------------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index ad8331c81..f60789f83 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -191,18 +191,18 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
       conv_alpha=conv_alpha,
   )
 
-  up_weight=None
-  if 'up_weight' in kwargs:
-    up_weight = kwargs.get('up_weight',None)
-    if "," in up_weight:
-      up_weight = [float(s) for s in up_weight.split(",") if s]
-  down_weight=None
-  if 'down_weight' in kwargs:
-    down_weight = kwargs.get('down_weight',None)
-    if "," in down_weight:
-      down_weight = [float(s) for s in down_weight.split(",") if s]
-
-  network.set_stratified_lr_weight(up_weight,float(kwargs.get('mid_weight', 1.0)),down_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
+  up_lr_weight=None
+  if 'up_lr_weight' in kwargs:
+    up_lr_weight = kwargs.get('up_lr_weight',None)
+    if "," in up_lr_weight:
+      up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
+  down_lr_weight=None
+  if 'down_lr_weight' in kwargs:
+    down_lr_weight = kwargs.get('down_lr_weight',None)
+    if "," in down_lr_weight:
+      down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
+  mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
+  network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))
 
   return network
@@ -334,9 +334,9 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)
 
-    self.up_weight:list[float] = None
-    self.down_weight:list[float] = None
-    self.mid_weight:float = None
+    self.up_lr_weight:list[float] = None
+    self.down_lr_weight:list[float] = None
+    self.mid_lr_weight:float = None
 
   def set_multiplier(self, multiplier):
     self.multiplier = multiplier
@@ -433,7 +433,7 @@ def merge_to(self, text_encoder, unet, dtype, device):
     print(f"weights are merged")
 
   # define per-layer multipliers applied to the learning rate (stratified learning rates)
-  def set_stratified_lr_weight(self, up_weight:list[float]|str=None, mid_weight:float=None, down_weight:list[float]|str=None, zero_threshold:float=0.0):
+  def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
     max_len = 3  # attentions -> attentions -> attentions: defined over 3 modules
     if self.apply_to_conv2d_3x3:
       max_len = 10  # (resnets -> {up,down}sampler -> attentions) x3 -> resnets: defined over 10 modules
@@ -451,33 +451,33 @@ def get_list(name) -> list[float]:
       elif name=="zeros":
         return [0.0] * max_len
       else:
-        print("Unknown weight argument %s was used.\n\tValid arguments: cosine, sine, linear, reverse_linear, zeros"%(name))
+        print("Unknown lr_weight argument %s was used.\n\tValid arguments: cosine, sine, linear, reverse_linear, zeros"%(name))
         return None
 
-    if type(down_weight)==str:
-      down_weight=get_list(down_weight)
-    if type(up_weight)==str:
-      up_weight=get_list(up_weight)
+    if type(down_lr_weight)==str:
+      down_lr_weight=get_list(down_lr_weight)
+    if type(up_lr_weight)==str:
+      up_lr_weight=get_list(up_lr_weight)
 
-    if (up_weight != None and len(up_weight)>max_len) or (down_weight != None and len(down_weight)>max_len):
+    if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
       print("down_weight or up_weight is too long; values after the %dth are ignored."%max_len)
-    if (up_weight != None and len(up_weight)<max_len) or (down_weight != None and len(down_weight)<max_len):
-      if (down_weight != None and len(down_weight)<max_len):
-        down_weight = down_weight + [1.0] * (max_len - len(down_weight))
-      if (up_weight != None and len(up_weight)<max_len):
-        up_weight = up_weight + [1.0] * (max_len - len(up_weight))
-    if (up_weight != None) or (mid_weight != None) or (down_weight != None):
+    if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
+      if (down_lr_weight != None and len(down_lr_weight)<max_len):
+        down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+      if (up_lr_weight != None and len(up_lr_weight)<max_len):
+        up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
       print("Applying stratified learning rates.")
-    if (down_weight != None):
-      self.down_weight = [w if w > zero_threshold else 0 for w in down_weight[:max_len]]
-      print("down_weight (shallow -> deep):",self.down_weight)
-    if (mid_weight != None):
-      self.mid_weight = mid_weight if mid_weight > zero_threshold else 0
-      print("mid_weight:",self.mid_weight)
-    if (up_weight != None):
-      self.up_weight = [w if w > zero_threshold else 0 for w in up_weight[:max_len]]
-      print("up_weight (deep -> shallow):",self.up_weight)
+    if (down_lr_weight != None):
+      self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
+      print("down_lr_weight (shallow -> deep):",self.down_lr_weight)
+    if (mid_lr_weight != None):
+      self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+      print("mid_lr_weight:",self.mid_lr_weight)
+    if (up_lr_weight != None):
+      self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
+      print("up_lr_weight (deep -> shallow):",self.up_lr_weight)
     return
 
   def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
@@ -501,12 +501,12 @@ def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
       if g[0]=="up":
         idx=i-1
 
-      if (g[0]=="up") and (self.up_weight != None):
-        return self.up_weight[idx]
-      elif (g[0]=="down") and (self.down_weight != None):
-        return self.down_weight[idx]
-      elif ("mid_block_" in lora.lora_name) and (self.mid_weight != None):
-        return self.mid_weight
+      if (g[0]=="up") and (self.up_lr_weight != None):
+        return self.up_lr_weight[idx]
+      elif (g[0]=="down") and (self.down_lr_weight != None):
+        return self.down_lr_weight[idx]
+      elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):
+        return self.mid_lr_weight
     # print({'params': lora.parameters(), 'lr':alpha*lr})
     return 1
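
For reference, the preset curves accepted by set_stratified_lr_weight all map block index 0..max_len-1 to a multiplier in [0, 1]. Restated standalone (same formulas as get_list above; preset_lr_weights is an illustrative name):

import math

def preset_lr_weights(name, max_len):
    if name == "cosine":          # 1 -> 0 from the first block to the last
        return [math.cos(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
    if name == "sine":            # 0 -> 1 from the first block to the last
        return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
    if name == "linear":
        return [i / (max_len - 1) for i in range(max_len)]
    if name == "reverse_linear":
        return [i / (max_len - 1) for i in reversed(range(max_len))]
    if name == "zeros":
        return [0.0] * max_len
    raise ValueError(f"unknown preset: {name}")

print(preset_lr_weights("sine", 4))  # [0.0, 0.5, 0.866..., 1.0]
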
From 3032a47af4ce9d08628561f1b759975951d3ddc3 Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Fri, 31 Mar 2023 01:42:57 +0900
Subject: [PATCH 5/9] Replace the cosine preset with a reversed sine

---
 networks/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/networks/lora.py b/networks/lora.py
index f60789f83..bb8f356e2 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -441,7 +441,7 @@ def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_wei
     def get_list(name) -> list[float]:
       import math
       if name=="cosine":
-        return [math.cos(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
+        return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))]
       elif name=="sine":
         return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
       elif name=="linear":
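
This change is behavior-preserving: cos(pi*x/2) equals sin(pi*(1-x)/2), so the cosine preset and the reversed sine preset produce the same list up to floating-point rounding. A quick standalone check:

import math

max_len = 4
cosine = [math.cos(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
reversed_sine = [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
assert all(abs(a - b) < 1e-12 for a, b in zip(cosine, reversed_sine))
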
return 0, "" caption_by_folder = '_'.join(tokens[1:]) return n_repeats, caption_by_folder From 1e164b6ec37eff1034c213628dfc75105922b233 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 31 Mar 2023 12:52:39 +0900 Subject: [PATCH 7/9] specify device when loading state_dict --- library/model_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/model_util.py b/library/model_util.py index 9b4405ebd..32a9c87af 100644 --- a/library/model_util.py +++ b/library/model_util.py @@ -866,7 +866,7 @@ def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): # TODO dtype指定の動作が怪しいので確認する text_encoderを指定形式で作れるか未確認 def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, device='cpu', dtype=None): - _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) # no need to specify device in loading state_dict + _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path, device) # Convert the UNet2DConditionModel model. unet_config = create_unet_diffusers_config(v2) From 058e442072582f16a591bf5fb5f395f953767501 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 2 Apr 2023 04:02:34 +0900 Subject: [PATCH 8/9] =?UTF-8?q?=E3=83=AC=E3=82=A4=E3=83=A4=E3=83=BC?= =?UTF-8?q?=E6=95=B0=E5=A4=89=E6=9B=B4(hako-mikan/sd-webui-lora-block-weig?= =?UTF-8?q?ht=E5=8F=82=E8=80=83)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- networks/lora.py | 54 +++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 31 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index bb8f356e2..cfc517ce4 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -337,6 +337,7 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) self.up_lr_weight:list[float] = None self.down_lr_weight:list[float] = None self.mid_lr_weight:float = None + self.stratified_lr = False def set_multiplier(self, multiplier): self.multiplier = multiplier @@ -434,10 +435,7 @@ def merge_to(self, text_encoder, unet, dtype, device): # 層別学習率用に層ごとの学習率に対する倍率を定義する def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0): - max_len = 3 # attentions -> attentions -> attentions で3個のModuleに対して定義 - if self.apply_to_conv2d_3x3: - max_len = 10 # (resnets -> {up,down}sampler -> attentions) x3 -> resnets で10個のModuleに対して定義 - + max_len=12 # フルモデル相当でのup,downの層の数 def get_list(name) -> list[float]: import math if name=="cosine": @@ -469,6 +467,7 @@ def get_list(name) -> list[float]: up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): print("層別学習率を適用します。") + self.stratified_lr = True if (down_lr_weight != None): self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]] print("down_lr_weight(浅い層->深い層):",self.down_lr_weight) @@ -483,31 +482,22 @@ def get_list(name) -> list[float]: def get_stratified_lr_weight(self, lora:LoRAModule) -> float: m = RE_UPDOWN.search(lora.lora_name) if m: - idx = 0 g = m.groups() i = int(g[1]) - if self.apply_to_conv2d_3x3: - if g[2]=="resnets": - idx=3*i - elif g[2]=="attentions": - if g[0]=="down": - idx=3*i + 2 - else: - idx=3*i - 1 - elif g[2]=="upsamplers" or g[2]=="downsamplers": - idx=3*i + 1 - else: - idx=i - if g[0]=="up": - idx=i-1 - - if (g[0]=="up") and 
From 19340d82e6fb2a081cadb5fc4c6f38aa627ea81d Mon Sep 17 00:00:00 2001
From: u-haru <40634644+u-haru@users.noreply.github.com>
Date: Sun, 2 Apr 2023 12:57:55 +0900
Subject: [PATCH 9/9] Group params together when stratified learning rates are not used

---
 networks/lora.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/networks/lora.py b/networks/lora.py
index cfc517ce4..6e860a03d 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -514,17 +514,25 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr , default_lr):
       all_params.append(param_data)
 
     if self.unet_loras:
-      for lora in self.unet_loras:
-        param_data = {'params': lora.parameters()}
+      if self.stratified_lr:
+        for lora in self.unet_loras:
+          param_data = {'params': lora.parameters()}
+          if unet_lr is not None:
+            param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+          elif default_lr is not None:
+            param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
+          if ('lr' in param_data) and (param_data['lr']==0):
+            continue
+          all_params.append(param_data)
+      else:
+        params = []
+        for lora in self.unet_loras:
+          params.extend(lora.parameters())
+        param_data = {'params': params}
         if unet_lr is not None:
           param_data['lr'] = unet_lr
-        elif default_lr is not None:
-          param_data['lr'] = default_lr
-        if self.stratified_lr and ('lr' in param_data):
-          param_data['lr'] = param_data['lr'] * self.get_stratified_lr_weight(lora)
-        if (param_data['lr']==0):
-          continue
         all_params.append(param_data)
+
     return all_params
 
   def enable_gradient_checkpointing(self):
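
The series ends with two optimizer layouts side by side: a single parameter group when every U-Net module trains at the same rate, and one group per module when stratified weights are in effect, with zero-weight groups dropped. An illustrative contrast using dummy modules and made-up multipliers:

import torch

modules = [torch.nn.Linear(4, 4) for _ in range(3)]
weights = [0.0, 0.5, 1.0]  # per-module lr multipliers (illustrative)
base_lr = 1e-4

# uniform case: one group keeps the optimizer setup compact
uniform = torch.optim.AdamW(
    [{"params": [p for m in modules for p in m.parameters()], "lr": base_lr}])

# stratified case: one group per module; zero-weight modules are excluded
groups = [{"params": m.parameters(), "lr": base_lr * w}
          for m, w in zip(modules, weights) if w > 0]
stratified = torch.optim.AdamW(groups)
print(len(stratified.param_groups))  # 2
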