diff --git a/library/config_util.py b/library/config_util.py
index 97bbb4a8d..217646752 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -445,7 +445,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
       try:
         n_repeats = int(tokens[0])
       except ValueError as e:
-        print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
+        print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
         return 0, ""
       caption_by_folder = '_'.join(tokens[1:])
       return n_repeats, caption_by_folder
diff --git a/networks/lora.py b/networks/lora.py
index 2bf785118..6e860a03d 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -8,9 +8,11 @@ from typing import List
 import numpy as np
 import torch
+import re
 
 from library import train_util
 
+RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
 
 class LoRAModule(torch.nn.Module):
   """
@@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
     else:
       conv_block_alphas = [int(a) for a in conv_block_alphas(',')]
       assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
-  """
+  """
 
   network = LoRANetwork(
       text_encoder,
@@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
       conv_lora_dim=conv_dim,
       conv_alpha=conv_alpha,
   )
+
+  up_lr_weight = None
+  if 'up_lr_weight' in kwargs:
+    up_lr_weight = kwargs.get('up_lr_weight', None)
+    if "," in up_lr_weight:
+      up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
+  down_lr_weight = None
+  if 'down_lr_weight' in kwargs:
+    down_lr_weight = kwargs.get('down_lr_weight', None)
+    if "," in down_lr_weight:
+      down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
+  mid_lr_weight = float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
+  network.set_stratified_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get('stratified_zero_threshold', 0.0)))
+
   return network
 
 
@@ -318,6 +334,11 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)
 
+    self.up_lr_weight: list[float] = None
+    self.down_lr_weight: list[float] = None
+    self.mid_lr_weight: float = None
+    self.stratified_lr = False
+
   def set_multiplier(self, multiplier):
     self.multiplier = multiplier
     for lora in self.text_encoder_loras + self.unet_loras:
@@ -366,9 +387,17 @@ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None)
     else:
       self.unet_loras = []
 
+    skipped = []
     for lora in self.text_encoder_loras + self.unet_loras:
+      if self.get_stratified_lr_weight(lora) == 0:
+        skipped.append(lora.lora_name)
+        continue
       lora.apply_to()
       self.add_module(lora.lora_name, lora)
+    if len(skipped) > 0:
+      print(f"because stratified_lr_weight is 0, {len(skipped)} LoRA modules are skipped / stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
+      for name in skipped:
+        print(f"\t{name}")
 
     if self.weights_sd:
       # if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
@@ -404,34 +433,112 @@ def merge_to(self, text_encoder, unet, dtype, device):
       lora.merge_to(sd_for_lora, dtype, device)
     print(f"weights are merged")
 
-  def enable_gradient_checkpointing(self):
-    # not supported
-    pass
-
-  def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
-    def enumerate_params(loras):
-      params = []
-      for lora in loras:
-        params.extend(lora.parameters())
-      return params
-
+  # 層別学習率用に層ごとの学習率に対する倍率を定義する / define per-block multipliers for stratified (block-wise) learning rates
+  def set_stratified_lr_weight(self, up_lr_weight: list[float] | str = None, mid_lr_weight: float = None, down_lr_weight: list[float] | str = None, zero_threshold: float = 0.0):
+    max_len = 12  # フルモデル相当でのup,downの層の数 / number of up/down block slots in the full model
+    def get_list(name) -> list[float]:
+      import math
+      if name == "cosine":
+        return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
+      elif name == "sine":
+        return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
+      elif name == "linear":
+        return [i / (max_len - 1) for i in range(max_len)]
+      elif name == "reverse_linear":
+        return [i / (max_len - 1) for i in reversed(range(max_len))]
+      elif name == "zeros":
+        return [0.0] * max_len
+      else:
+        print("unknown lr_weight argument %s is used. valid arguments: cosine, sine, linear, reverse_linear, zeros / 不明なlr_weightの引数 %s が使われました。有効な引数: cosine, sine, linear, reverse_linear, zeros" % (name, name))
+        return None
+
+    if type(down_lr_weight) == str:
+      down_lr_weight = get_list(down_lr_weight)
+    if type(up_lr_weight) == str:
+      up_lr_weight = get_list(up_lr_weight)
+
+    if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
+      print("down_weight or up_weight is too long; parameters after the %d-th are ignored / down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % (max_len, max_len))
+    if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
+      print("down_weight or up_weight is too short; missing entries are filled with 1 / down_weightもしくはup_weightが短すぎます。%d個に満たない分は1で補われます。" % max_len)
+      if up_lr_weight != None and len(up_lr_weight) < max_len:
+        up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
+      if down_lr_weight != None and len(down_lr_weight) < max_len:
+        down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
+
+    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
+      print("apply stratified (block-wise) learning rate / 階層別学習率を適用します。")
+      self.stratified_lr = True
+
+    if (down_lr_weight != None):
+      self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
+      print("down_lr_weight (shallow -> deep / 浅い層->深い層):", self.down_lr_weight)
+    if (mid_lr_weight != None):
+      self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
+      print("mid_lr_weight:", self.mid_lr_weight)
+    if (up_lr_weight != None):
+      self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
+      print("up_lr_weight (deep -> shallow / 深い層->浅い層):", self.up_lr_weight)
+    return
+
+  def get_stratified_lr_weight(self, lora: LoRAModule) -> float:
+    m = RE_UPDOWN.search(lora.lora_name)
+    if m:
+      g = m.groups()
+      i = int(g[1])
+      j = int(g[3])
+      if g[2] == "resnets":
+        idx = 3 * i + j
+      elif g[2] == "attentions":
+        idx = 3 * i + j
+      elif g[2] == "upsamplers" or g[2] == "downsamplers":
+        idx = 3 * i + 2
+
+      if (g[0] == "down") and (self.down_lr_weight != None):
+        return self.down_lr_weight[idx + 1]
+      elif (g[0] == "up") and (self.up_lr_weight != None):
+        return self.up_lr_weight[idx]
+    elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
+      return self.mid_lr_weight
+    return 1
+
+  def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
     self.requires_grad_(True)
     all_params = []
 
     if self.text_encoder_loras:
-      param_data = {"params": enumerate_params(self.text_encoder_loras)}
+      params = []
+      for lora in self.text_encoder_loras:
+        params.extend(lora.parameters())
+      param_data = {'params': params}
       if text_encoder_lr is not None:
-        param_data["lr"] = text_encoder_lr
+        param_data['lr'] = text_encoder_lr
       all_params.append(param_data)
 
     if self.unet_loras:
-      param_data = {"params": enumerate_params(self.unet_loras)}
-      if unet_lr is not None:
-        param_data["lr"] = unet_lr
-      all_params.append(param_data)
+      if self.stratified_lr:
+        for lora in self.unet_loras:
+          param_data = {'params': lora.parameters()}
+          if unet_lr is not None:
+            param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
+          elif default_lr is not None:
+            param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
+          if ('lr' in param_data) and (param_data['lr'] == 0):
+            continue
+          all_params.append(param_data)
+      else:
+        params = []
+        for lora in self.unet_loras:
+          params.extend(lora.parameters())
+        param_data = {'params': params}
+        if unet_lr is not None:
+          param_data['lr'] = unet_lr
+        all_params.append(param_data)
 
     return all_params
 
+  def enable_gradient_checkpointing(self):
+    # not supported
+    pass
+
   def prepare_grad_etc(self, text_encoder, unet):
     self.requires_grad_(True)
diff --git a/train_network.py b/train_network.py
index 476f76dfc..2b824018f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -196,7 +196,7 @@ def train(args):
 
   # 学習に必要なクラスを準備する
   print("prepare optimizer, data loader etc.")
-  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
+  trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
   optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
 
   # dataloaderを準備する
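For reference, a minimal standalone sketch (not part of the patch) of how the weight presets and the RE_UPDOWN index mapping in set_stratified_lr_weight / get_stratified_lr_weight behave. The example module name and the command line at the end are illustrative assumptions: they rely on the usual "lora_unet_..." naming produced by networks/lora.py and on train_network.py forwarding --network_args key=value pairs to create_network as kwargs.

import math
import re

RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
max_len = 12  # number of up/down weight slots expected by set_stratified_lr_weight

# "sine" preset: 12 weights rising from 0 toward 1
sine = [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
print([round(w, 2) for w in sine])

# Same index arithmetic as get_stratified_lr_weight: 3*i + j for resnets/attentions,
# 3*i + 2 for up/downsamplers. The module name below is an illustrative example.
name = "lora_unet_up_blocks_1_attentions_2_transformer_blocks_0_attn1_to_q"
g = RE_UPDOWN.search(name).groups()  # ('up', '1', 'attentions', '2')
idx = 3 * int(g[1]) + (2 if g[2] in ("upsamplers", "downsamplers") else int(g[3]))
print(g[0], idx, round(sine[idx], 2))  # up 5 0.65 -> with up_lr_weight=sine this module trains at 0.65 * unet_lr

# Hypothetical invocation (assumes --network_args is forwarded to create_network):
#   --network_args "down_lr_weight=sine" "mid_lr_weight=1.0" "up_lr_weight=cosine" "stratified_zero_threshold=0.1"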