Merge pull request #355 from u-haru/feature/stratified_lr
Implement per-layer (stratified) learning rates for LoRA, remove the device specification when loading state_dict, fix a typo
kohya-ss authored Apr 2, 2023
2 parents f037b09 + 19340d8 commit 36c8a4a
Showing 3 changed files with 127 additions and 20 deletions.
2 changes: 1 addition & 1 deletion library/config_util.py
@@ -445,7 +445,7 @@ def extract_dreambooth_params(name: str) -> Tuple[int, str]:
try:
n_repeats = int(tokens[0])
except ValueError as e:
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {dir}")
print(f"ignore directory without repeats / 繰り返し回数のないディレクトリを無視します: {name}")
return 0, ""
caption_by_folder = '_'.join(tokens[1:])
return n_repeats, caption_by_folder
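
For reference, the corrected message fires when a DreamBooth folder name lacks a leading repeat count. Below is a minimal sketch of the parsing rule, assuming tokens comes from splitting name on underscores (the split itself is outside the hunk shown); the helper name extract_repeats is made up for illustration:

def extract_repeats(name: str):
    # split e.g. "10_my_character" into a repeat count and a caption
    tokens = name.split("_")
    try:
        n_repeats = int(tokens[0])
    except ValueError:
        return 0, ""  # no leading repeat count -> the directory is ignored
    return n_repeats, "_".join(tokens[1:])

print(extract_repeats("10_my_character"))  # -> (10, 'my_character')
print(extract_repeats("pictures"))         # -> (0, '')
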
143 changes: 125 additions & 18 deletions networks/lora.py
@@ -8,9 +8,11 @@
from typing import List
import numpy as np
import torch
import re

from library import train_util

RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')

class LoRAModule(torch.nn.Module):
"""
@@ -177,7 +179,7 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
else:
conv_block_alphas = [int(a) for a in conv_block_alphas.split(',')]
assert len(conv_block_alphas) == NUM_BLOCKS, f"Number of block alphas is not the same as {NUM_BLOCKS}"
"""
"""

network = LoRANetwork(
text_encoder,
@@ -188,6 +190,20 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
conv_lora_dim=conv_dim,
conv_alpha=conv_alpha,
)

up_lr_weight=None
if 'up_lr_weight' in kwargs:
up_lr_weight = kwargs.get('up_lr_weight',None)
if "," in up_lr_weight:
up_lr_weight = [float(s) for s in up_lr_weight.split(",") if s]
down_lr_weight=None
if 'down_lr_weight' in kwargs:
down_lr_weight = kwargs.get('down_lr_weight',None)
if "," in down_lr_weight:
down_lr_weight = [float(s) for s in down_lr_weight.split(",") if s]
mid_lr_weight=float(kwargs.get('mid_lr_weight', 1.0)) if 'mid_lr_weight' in kwargs else None
network.set_stratified_lr_weight(up_lr_weight,mid_lr_weight,down_lr_weight,float(kwargs.get('stratified_zero_threshold', 0.0)))

return network
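
The kwargs handling above accepts either a comma-separated list of floats or a preset curve name (the name is resolved later by set_stratified_lr_weight). A minimal standalone sketch of that parsing rule; the helper name parse_block_lr_weight is made up and not part of this commit:

def parse_block_lr_weight(value):
    # comma-separated floats become a list; anything else is treated as a
    # preset curve name and resolved later by set_stratified_lr_weight
    if value is None:
        return None
    if "," in value:
        return [float(s) for s in value.split(",") if s]
    return value

print(parse_block_lr_weight("0,0.25,0.5,0.75,1"))  # -> [0.0, 0.25, 0.5, 0.75, 1.0]
print(parse_block_lr_weight("cosine"))             # -> 'cosine'
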


@@ -318,6 +334,11 @@ def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules)
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name)

self.up_lr_weight:list[float] = None
self.down_lr_weight:list[float] = None
self.mid_lr_weight:float = None
self.stratified_lr = False

def set_multiplier(self, multiplier):
self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras:
@@ -366,9 +387,17 @@ def apply_to(self, text_encoder, unet, apply_text_encoder=None, apply_unet=None)
else:
self.unet_loras = []

skipped = []
for lora in self.text_encoder_loras + self.unet_loras:
if self.get_stratified_lr_weight(lora) == 0:
skipped.append(lora.lora_name)
continue
lora.apply_to()
self.add_module(lora.lora_name, lora)
if len(skipped)>0:
print(f"stratified_lr_weightが0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:")
for name in skipped:
print(f"\t{name}")

if self.weights_sd:
# if some weights are not in state dict, it is ok because initial LoRA does nothing (lora_up is initialized by zeros)
@@ -404,34 +433,112 @@ def merge_to(self, text_encoder, unet, dtype, device):
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")

def enable_gradient_checkpointing(self):
# not supported
pass

def prepare_optimizer_params(self, text_encoder_lr, unet_lr):
def enumerate_params(loras):
params = []
for lora in loras:
params.extend(lora.parameters())
return params

# for layer-wise (stratified) learning rates: define a per-layer multiplier applied to the learning rate
def set_stratified_lr_weight(self, up_lr_weight:list[float]|str=None, mid_lr_weight:float=None, down_lr_weight:list[float]|str=None, zero_threshold:float=0.0):
max_len=12 # number of up/down layers in the full model
def get_list(name) -> list[float]:
import math
if name=="cosine":
return [math.sin(math.pi*(i/(max_len-1))/2) for i in reversed(range(max_len))]
elif name=="sine":
return [math.sin(math.pi*(i/(max_len-1))/2) for i in range(max_len)]
elif name=="linear":
return [i/(max_len-1) for i in range(max_len)]
elif name=="reverse_linear":
return [i/(max_len-1) for i in reversed(range(max_len))]
elif name=="zeros":
return [0.0] * max_len
else:
print("不明なlr_weightの引数 %s が使われました。\n\t有効な引数: cosine, sine, linear, reverse_linear, zeros"%(name))
return None

if type(down_lr_weight)==str:
down_lr_weight=get_list(down_lr_weight)
if type(up_lr_weight)==str:
up_lr_weight=get_list(up_lr_weight)

if (up_lr_weight != None and len(up_lr_weight)>max_len) or (down_lr_weight != None and len(down_lr_weight)>max_len):
print("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。"%max_len)
if (up_lr_weight != None and len(up_lr_weight)<max_len) or (down_lr_weight != None and len(down_lr_weight)<max_len):
print("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。"%max_len)
if down_lr_weight != None and len(down_lr_weight)<max_len:
down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight))
if up_lr_weight != None and len(up_lr_weight)<max_len:
up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight))
if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None):
print("層別学習率を適用します。")
self.stratified_lr = True
if (down_lr_weight != None):
self.down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight[:max_len]]
print("down_lr_weight(浅い層->深い層):",self.down_lr_weight)
if (mid_lr_weight != None):
self.mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0
print("mid_lr_weight:",self.mid_lr_weight)
if (up_lr_weight != None):
self.up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight[:max_len]]
print("up_lr_weight(深い層->浅い層):",self.up_lr_weight)
return
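
The preset curves handled by get_list are simple closed-form lists over the 12 up/down positions. The following standalone sketch reproduces them with the same formulas, with max_len fixed at 12 and values rounded only for display:

import math

max_len = 12  # number of up/down positions, as in set_stratified_lr_weight

curves = {
    "cosine": [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in reversed(range(max_len))],
    "sine": [math.sin(math.pi * (i / (max_len - 1)) / 2) for i in range(max_len)],
    "linear": [i / (max_len - 1) for i in range(max_len)],
    "reverse_linear": [i / (max_len - 1) for i in reversed(range(max_len))],
    "zeros": [0.0] * max_len,
}
for name, weights in curves.items():
    print(name, [round(w, 2) for w in weights])
# "sine" ramps 0 -> 1 across the twelve layers, "cosine" ramps 1 -> 0;
# "linear" and "reverse_linear" are the straight-line counterparts.
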

def get_stratified_lr_weight(self, lora:LoRAModule) -> float:
m = RE_UPDOWN.search(lora.lora_name)
if m:
g = m.groups()
i = int(g[1])
j = int(g[3])
if g[2]=="resnets":
idx=3*i + j
elif g[2]=="attentions":
idx=3*i + j
elif g[2]=="upsamplers" or g[2]=="downsamplers":
idx=3*i + 2

if (g[0]=="down") and (self.down_lr_weight != None):
return self.down_lr_weight[idx+1]
elif (g[0]=="up") and (self.up_lr_weight != None):
return self.up_lr_weight[idx]
elif ("mid_block_" in lora.lora_name) and (self.mid_lr_weight != None): # idx=12
return self.mid_lr_weight
return 1
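
To make the lookup concrete, here is a standalone sketch that runs a few illustrative module names through the same regex and index arithmetic. The names and the example down_lr_weight list are made up, and only the down side is shown (note the idx+1 offset it uses):

import re

RE_UPDOWN = re.compile(r'(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_')
down_lr_weight = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # illustrative

def down_weight_for(name: str) -> float:
    m = RE_UPDOWN.search(name)
    if m is None or m.group(1) != "down":
        return 1.0  # up/mid handling omitted in this sketch
    i, kind, j = int(m.group(2)), m.group(3), int(m.group(4))
    idx = 3 * i + 2 if kind in ("upsamplers", "downsamplers") else 3 * i + j
    return down_lr_weight[idx + 1]

print(down_weight_for("lora_unet_down_blocks_1_attentions_0_proj_in"))  # idx=3 -> 0.8
print(down_weight_for("lora_unet_down_blocks_2_downsamplers_0_conv"))   # idx=8 -> 1.0
print(down_weight_for("lora_unet_mid_block_attentions_0_proj_out"))     # no match -> 1.0
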

def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self.requires_grad_(True)
all_params = []

if self.text_encoder_loras:
param_data = {"params": enumerate_params(self.text_encoder_loras)}
params = []
for lora in self.text_encoder_loras:
params.extend(lora.parameters())
param_data = {'params': params}
if text_encoder_lr is not None:
param_data["lr"] = text_encoder_lr
param_data['lr'] = text_encoder_lr
all_params.append(param_data)

if self.unet_loras:
param_data = {"params": enumerate_params(self.unet_loras)}
if unet_lr is not None:
param_data["lr"] = unet_lr
all_params.append(param_data)
if self.stratified_lr:
for lora in self.unet_loras:
param_data = {'params': lora.parameters()}
if unet_lr is not None:
param_data['lr'] = unet_lr * self.get_stratified_lr_weight(lora)
elif default_lr is not None:
param_data['lr'] = default_lr * self.get_stratified_lr_weight(lora)
if ('lr' in param_data) and (param_data['lr']==0):
continue
all_params.append(param_data)
else:
params = []
for lora in self.unet_loras:
params.extend(lora.parameters())
param_data = {'params': params}
if unet_lr is not None:
param_data['lr'] = unet_lr
all_params.append(param_data)

return all_params
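
The returned structure is the standard list of parameter-group dicts that torch optimizers accept, with an optional per-group "lr". A hedged sketch of how the trainer side consumes it; AdamW and the lr values are chosen arbitrarily here, and the two Linear layers merely stand in for LoRA modules:

import torch

te_lora = torch.nn.Linear(4, 4)    # stand-in for a text encoder LoRA module
unet_lora = torch.nn.Linear(4, 4)  # stand-in for a U-Net LoRA module

all_params = [
    {"params": list(te_lora.parameters()), "lr": 5e-5},
    {"params": list(unet_lora.parameters()), "lr": 1e-4 * 0.5},  # unet_lr scaled by a stratified weight
]
optimizer = torch.optim.AdamW(all_params, lr=1e-4)  # groups without "lr" fall back to this default
print([group["lr"] for group in optimizer.param_groups])  # -> [5e-05, 5e-05]
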

def enable_gradient_checkpointing(self):
# not supported
pass

def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True)

Expand Down
2 changes: 1 addition & 1 deletion train_network.py
@@ -196,7 +196,7 @@ def train(args):
# prepare the classes required for training
print("prepare optimizer, data loader etc.")

trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)

# prepare the dataloader
