Implement per-layer learning rates for LoRA, remove device specification when loading state_dict, fix typos #355

Merged: 11 commits, Apr 2, 2023
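For orientation: this PR adds per-U-Net-block learning-rate multipliers for LoRA (12 values for the down blocks, one for the mid block, and 12 for the up blocks), plus preset curves (cosine, sine, linear, reverse_linear, zeros). The sketch below is an editor's illustration, not code from the diff: it shows the new options as create_network receives them, assuming the key=value pairs arrive the usual way via the trainer's --network_args and that vae/text_encoder/unet are already-loaded models. Values are illustrative only.

import networks.lora as lora

net_kwargs = {
    "up_lr_weight": "sine",                             # preset name, or 12 comma-separated floats (deep -> shallow)
    "down_lr_weight": "1,1,1,1,1,1,0.5,0.5,0.5,0,0,0",  # 12 floats (shallow -> deep)
    "mid_lr_weight": "0.5",
    "stratified_zero_threshold": "0.1",                 # weights at or below this become 0; those modules are skipped
}
network = lora.create_network(1.0, 8, 1, vae, text_encoder, unet, **net_kwargs)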
library/config_util.py: 2 changes (1 addition, 1 deletion)
@@ -445,7 +445,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
try:
n_repeats = int(tokens[0])
except ValueError as e:
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
return 0, ""
caption_by_folder = '_'.join(tokens[1:])
return n_repeats, caption_by_folder
Expand Down
networks/lora.py: 143 changes (125 additions, 18 deletions)
@@ -8,9 +8,11 @@
from typing import List
import numpy as np
import torch
import re

from library import train_util

RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
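# --- Editor's illustration (not part of the diff) ---------------------------------
# RE_UPDOWN pulls the block position out of LoRA module names, e.g.:
#   >>> RE_UPDOWN.search("lora_unet_up_blocks_1_attentions_2_transformer_blocks_0_attn1_to_q").groups()
#   ('up', '1', 'attentions', '2')
# get_stratified_lr_weight (added further down) uses these groups to map each module
# to a block index. The module name above is representative, not exhaustive.
# -----------------------------------------------------------------------------------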

class LoRAModule(torch.nn.Module):
"""
@@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
else:
conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not same to {NUM_BLOCKS}"
"""
"""

network = LoRANetwork(
text_encoder,
@@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
)

up_lr_weight = None
if 'up_lr_weight' in kwargs:
    up_lr_weight = kwargs.get('up_lr_weight', None)
    if "," in up_lr_weight:
        up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
down_lr_weight = None
if 'down_lr_weight' in kwargs:
    down_lr_weight = kwargs.get('down_lr_weight', None)
    if "," in down_lr_weight:
        down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
mid_lr_weight = float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
network.set_stratified_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight, float(kwargs.get('stratified_zero_threshold', 0.0)))

return network


@@ -318,6 +334,11 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

self.up_lr_weight: list[float] = None
self.down_lr_weight: list[float] = None
self.mid_lr_weight: float = None
self.stratified_lr = False

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -366,9 +387,17 @@ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None)
else:
self.unet_loras = []

skipped = []
for lora in self.text_encoder_loras + self.unet_loras:
    if self.get_stratified_lr_weight(lora) == 0:
        skipped.append(lora.lora_name)
        continue
    lora.apply_to()
    self.add_module(lora.lora_name, lora)
if len(skipped) > 0:
    print(f"because their stratified_lr_weight is 0, the following {len(skipped)} LoRA modules are skipped:")
    for name in skipped:
        print(f"\t{name}")
Comment on lines +397 to +400

@kohya-ss (Owner), Apr 1, 2023:

It might be good to also remove the corresponding LoRA modules from text_encoder_loras and unet_loras here.

It seems that would need some further consideration.
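One way to implement that suggestion (an editor's sketch, not code from this PR) would be to filter both lists inside apply_to before the loop above, so that later steps such as merging and optimizer setup never see the skipped modules:

# editor's sketch of the suggestion above; not part of the PR
self.text_encoder_loras = [lora for lora in self.text_encoder_loras if self.get_stratified_lr_weight(lora) != 0]
self.unet_loras = [lora for lora in self.unet_loras if self.get_stratified_lr_weight(lora) != 0]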


if self.weights_sd:
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
@@ -404,34 +433,112 @@ def merge_to(self, text_encoder, unet, dtype, device):
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")

def enable_gradient_checkpointing(self):
# not supported
pass

def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
def enumerate_params(loras):
params = []
for lora in loras:
params.extend(lora.parameters())
return params

# define per-layer multipliers applied to the learning rate (layer-wise / stratified LR)
def set_stratified_lr_weight(self, up_lr_weight: list[float] | str = None, mid_lr_weight: float = None, down_lr_weight: list[float] | str = None, zero_threshold: float = 0.0):
    max_len = 12  # number of up/down blocks in the full model

    def get_list(name) -> list[float]:
        import math
        if name == "cosine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))]
        elif name == "sine":
            return [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)]
        elif name == "linear":
            return [i / (max_len - 1) for i in range(max_len)]
        elif name == "reverse_linear":
            return [i / (max_len - 1) for i in reversed(range(max_len))]
        elif name == "zeros":
            return [0.0] * max_len
        else:
            print("unknown lr_weight argument %s was used.\n\tvalid arguments: cosine, sine, linear, reverse_linear, zeros" % (name))
            return None

    if type(down_lr_weight) == str:
        down_lr_weight = get_list(down_lr_weight)
    if type(up_lr_weight) == str:
        up_lr_weight = get_list(up_lr_weight)

    if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len):
        print("down_lr_weight or up_lr_weight is too long; values beyond the first %d are ignored." % max_len)
    if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len):
        print("down_lr_weight or up_lr_weight is too short; missing values up to %d are padded with 1." % max_len)
    if down_lr_weight != None and len(down_lr_weight) < max_len:
        down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
    if up_lr_weight != None and len(up_lr_weight) < max_len:
        up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))

    if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
        print("applying layer-wise (stratified) learning rates")
        self.stratified_lr = True
        if down_lr_weight != None:
            self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
            print("down_lr_weight (shallow -> deep):", self.down_lr_weight)
        if mid_lr_weight != None:
            self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
            print("mid_lr_weight:", self.mid_lr_weight)
        if up_lr_weight != None:
            self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
            print("up_lr_weight (deep -> shallow):", self.up_lr_weight)
    return

def get_stratified_lr_weight(self, lora: LoRAModule) -> float:
    m = RE_UPDOWN.search(lora.lora_name)
    if m:
        g = m.groups()
        i = int(g[1])
        j = int(g[3])
        # map the module to a flat block index: three slots per up/down block
        # (resnets/attentions use their sub-index, up/downsamplers take the last slot)
        if g[2] == "resnets":
            idx = 3 * i + j
        elif g[2] == "attentions":
            idx = 3 * i + j
        elif g[2] == "upsamplers" or g[2] == "downsamplers":
            idx = 3 * i + 2

        if (g[0] == "down") and (self.down_lr_weight != None):
            return self.down_lr_weight[idx + 1]
        elif (g[0] == "up") and (self.up_lr_weight != None):
            return self.up_lr_weight[idx]
    elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None):  # idx=12
        return self.mid_lr_weight
    return 1

def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
params = []
for lora in self.text_encoder_loras:
params.extend(lora.parameters())
param_data = {'params': params}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
param_data['lr'] = text_encoder_lr
all_params.append(param_data)

if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
if self.stratified_lr:
for lora in self.unet_loras:
param_data = {'params': lora.parameters()}
if unet_lr is not None:
param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
elif default_lr is not None:
param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
if ('lr' in param_data) and (param_data['lr']==0):
continue
all_params.append(param_data)
else:
params = []
for lora in self.unet_loras:
params.extend(lora.parameters())
param_data = {'params': params}
if unet_lr is not None:
param_data['lr'] = unet_lr
all_params.append(param_data)

return all_params

def enable_gradient_checkpointing(self):
# not supported
pass

def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)

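To make the new flow concrete, here is an end-to-end sketch by the editor (not from the diff): weights are registered on the network, prepare_optimizer_params then builds one parameter group per U-Net LoRA module with a scaled learning rate, and the groups are handed to a standard PyTorch optimizer. A plain AdamW is used purely for illustration; the training script actually selects its optimizer via train_util.get_optimizer.

import torch

# `network` is a LoRANetwork instance; learning rates and weights are illustrative.
network.set_stratified_lr_weight(
    up_lr_weight="sine",       # preset, equivalent to a 12-value list (deep -> shallow)
    mid_lr_weight=0.5,
    down_lr_weight="cosine",   # preset, equivalent to a 12-value list (shallow -> deep)
    zero_threshold=0.0,
)
params = network.prepare_optimizer_params(text_encoder_lr=5e-5, unet_lr=1e-4, default_lr=1e-4)
# with stratified_lr enabled, each U-Net LoRA gets its own group with
# lr = unet_lr (or default_lr) * block weight; groups whose lr is 0 are dropped
optimizer = torch.optim.AdamW(params, lr=1e-4)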
train_network.py: 2 changes (1 addition, 1 deletion)
@@ -196,7 +196,7 @@ def train(args):
# prepare the classes needed for training
print("prepare optimizer, data loader etc.")

trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

# prepare the dataloader