Enable ability to resize lora dim based off sv ratios #243

Merged · 5 commits · Mar 10, 2023
235 changes: 186 additions & 49 deletions networks/resize_lora.py
@@ -3,12 +3,13 @@
# Thanks to cloneofsimo and kohya

import argparse
import os
import torch
from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np

MIN_SV = 1e-6

def load_state_dict(file_name, dtype):
if model_util.is_safetensors(file_name):
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
torch.save(model, file_name)


def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
def index_sv_cumulative(S, target):
original_sum = float(torch.sum(S))
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
index = int(torch.searchsorted(cumulative_sums, target)) + 1
if index >= len(S):
index = len(S) - 1

return index


def index_sv_fro(S, target):
S_squared = S.pow(2)
s_fro_sq = float(torch.sum(S_squared))
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
if index >= len(S):
index = len(S) - 1

return index
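(Illustrative aside, not part of the patch: both helpers return how many leading singular values are needed to reach a target fraction, index_sv_cumulative measured against the plain sum of S and index_sv_fro against the Frobenius norm, i.e. the square root of the sum of squares. For S = [4, 2, 1, 0.5] and a 0.9 target, the cumulative-sum criterion needs the top 3 values (7/7.5 ≈ 0.93) while the Frobenius criterion is already met by the top 2 (sqrt(20/21.25) ≈ 0.97).)

# Sketch of the helpers above on a toy singular-value vector:
# import torch
# S = torch.tensor([4.0, 2.0, 1.0, 0.5])
# index_sv_cumulative(S, 0.9)  # -> 3, top-3 values cover >= 90% of sum(S)
# index_sv_fro(S, 0.9)         # -> 2, top-2 values cover >= 90% of ||S||_F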


# Modified from Kohaku-blueleaf's extract/merge functions
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size, kernel_size, _ = weight.size()
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]

U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]

param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
del U, S, Vh, weight
return param_dict


def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
out_size, in_size = weight.size()

U, S, Vh = torch.linalg.svd(weight.to(device))

param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
lora_rank = param_dict["new_rank"]

U = U[:, :lora_rank]
S = S[:lora_rank]
U = U @ torch.diag(S)
Vh = Vh[:lora_rank, :]

param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
del U, S, Vh, weight
return param_dict


def merge_conv(lora_down, lora_up, device):
in_rank, in_size, kernel_size, k_ = lora_down.shape
out_size, out_rank, _, _ = lora_up.shape
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

lora_down = lora_down.to(device)
lora_up = lora_up.to(device)

merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
del lora_up, lora_down
return weight


def merge_linear(lora_down, lora_up, device):
in_rank, in_size = lora_down.shape
out_size, out_rank = lora_up.shape
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

lora_down = lora_down.to(device)
lora_up = lora_up.to(device)

weight = lora_up @ lora_down
del lora_up, lora_down
return weight


def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
param_dict = {}

if dynamic_method=="sv_ratio":
# Calculate new dim and alpha based off ratio
max_sv = S[0]
min_sv = max_sv/dynamic_param
new_rank = max(torch.sum(S > min_sv).item(),1)
new_alpha = float(scale*new_rank)

elif dynamic_method=="sv_cumulative":
# Calculate new dim and alpha based off cumulative sum
new_rank = index_sv_cumulative(S, dynamic_param)
new_rank = max(new_rank, 1)
new_alpha = float(scale*new_rank)

elif dynamic_method=="sv_fro":
# Calculate new dim and alpha based off sqrt sum of squares
new_rank = index_sv_fro(S, dynamic_param)
new_rank = min(max(new_rank, 1), len(S)-1)
new_alpha = float(scale*new_rank)
else:
new_rank = rank
new_alpha = float(scale*new_rank)


if S[0] <= MIN_SV: # Zero matrix, set dim to 1
new_rank = 1
new_alpha = float(scale*new_rank)
elif new_rank > rank: # cap max rank at rank
new_rank = rank
new_alpha = float(scale*new_rank)


# Calculate resize info
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))

S_squared = S.pow(2)
s_fro = torch.sqrt(torch.sum(S_squared))
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
fro_percent = float(s_red_fro/s_fro)

param_dict["new_rank"] = new_rank
param_dict["new_alpha"] = new_alpha
param_dict["sum_retained"] = (s_rank)/s_sum
param_dict["fro_retained"] = fro_percent
param_dict["max_ratio"] = S[0]/S[new_rank]

return param_dict
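(Rough sketch of how these pieces fit together, illustrative only and not part of the patch: merge_linear rebuilds the full delta weight from an existing lora_down/lora_up pair, and extract_linear re-factorizes it, with rank_resize choosing the new rank by the selected dynamic method, capped at the requested maximum.)

# import torch
# down = torch.randn(8, 320)   # lora_down: (rank, in_size)
# up   = torch.randn(640, 8)   # lora_up:   (out_size, rank)
# full = merge_linear(down, up, device="cpu")              # (640, 320) delta weight
# params = extract_linear(full, 8, "sv_fro", 0.95, "cpu")  # keep ~95% of the Frobenius norm, rank capped at 8
# params["new_rank"], params["new_alpha"], params["fro_retained"]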


def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
verbose_str = "\n"

CLAMP_QUANTILE = 0.99
fro_list = []

# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
network_alpha = network_dim

scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale

print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")

lora_down_weight = None
lora_up_weight = None
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
block_down_name = None
block_up_name = None

print("resizing lora...")
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
conv2d = (len(lora_down_weight.size()) == 4)

if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()

if device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)

full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)

U, S, Vh = torch.linalg.svd(full_weight_matrix)
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
else:
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)

if verbose:
s_sum = torch.sum(torch.abs(S))
s_rank = torch.sum(torch.abs(S[:new_rank]))
verbose_str+=f"{block_down_name:76} | "
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
max_ratio = param_dict['max_ratio']
sum_retained = param_dict['sum_retained']
fro_retained = param_dict['fro_retained']
if not np.isnan(fro_retained):
fro_list.append(float(fro_retained))

U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
verbose_str+=f"{block_down_name:75} | "
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"

Vh = Vh[:new_rank, :]
if verbose and dynamic_method:
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
else:
verbose_str+=f"\n"

dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val

U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)

if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)

if device:
U = U.to(org_device)
Vh = Vh.to(org_device)

o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
new_alpha = param_dict['new_alpha']
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)

block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
del param_dict

if verbose:
print(verbose_str)

print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
return o_lora_sd, network_dim, new_alpha

@@ -151,6 +274,9 @@ def str_to_dtype(p):
return torch.bfloat16
return None

if args.dynamic_method and not args.dynamic_param:
raise Exception("If using dynamic_method, then dynamic_param is required")

merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
@@ -159,17 +285,23 @@ def str_to_dtype(p):
print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)

print("resizing rank...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
print("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)

# update metadata
if metadata is None:
metadata = {}

comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)

if not args.dynamic_method:
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
else:
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
metadata["ss_network_dim"] = 'Dynamic'
metadata["ss_network_alpha"] = 'Dynamic'

model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
@@ -193,6 +325,11 @@ def str_to_dtype(p):
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
parser.add_argument("--verbose", action="store_true",
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
parser.add_argument("--dynamic_param", type=float, default=None,
help="Specify target for dynamic reduction")


args = parser.parse_args()
resize(args)
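Taken together, a hypothetical invocation of the updated script (its pre-existing model, output, and precision arguments are untouched by this PR and elided below) might look like:

python networks/resize_lora.py <existing model and output arguments> --new_rank 32 --dynamic_method sv_fro --dynamic_param 0.9 --verbose

i.e. cap every block at rank 32 but let each block shrink further as long as it retains at least 90% of its Frobenius norm, printing per-block retention statistics along the way.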