From 57fa5ead92b612863d8904f3827ddd4ef873257a Mon Sep 17 00:00:00 2001
From: Maru-mee
Date: Mon, 9 Sep 2024 17:57:53 +0900
Subject: [PATCH 1/4] Support : OFT merge to base model

---
 networks/sdxl_merge_lora.py | 197 +++++++++++++++++++++++++++---------
 1 file changed, 149 insertions(+), 48 deletions(-)

diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index 3383a80de..49cf0081b 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -8,11 +8,15 @@ from library import sai_model_spec, sdxl_model_util, train_util
 import library.model_util as model_util
 import lora
+import oft
 from library.utils import setup_logging
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
+
+import concurrent.futures
+
 def load_state_dict(file_name, dtype):
     if os.path.splitext(file_name)[1] == ".safetensors":
         sd = load_file(file_name)
@@ -39,82 +43,179 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
     else:
         torch.save(model, file_name)
 
+def detect_method_from_training_model(models, dtype):
+    for model in models:
+        lora_sd, _ = load_state_dict(model, dtype)
+        for key in tqdm(lora_sd.keys()):
+            if 'lora_up' in key or 'lora_down' in key:
+                return 'LoRA'
+            elif "oft_blocks" in key:
+                return 'OFT'
 
 def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
     text_encoder1.to(merge_dtype)
     text_encoder1.to(merge_dtype)
     unet.to(merge_dtype)
+
+    # 方式を判定。OFT or LoRA_module
+    method = detect_method_from_training_model(models, merge_dtype)
+
+
+    print(f"method:{method}")
 
     # create module map
     name_to_module = {}
     for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
-        if i <= 1:
-            if i == 0:
-                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
-            else:
-                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
-            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
-        else:
-            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
-            target_replace_modules = (
-                lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
-            )
+        if method == 'LoRA':
+            if i <= 1:
+                if i == 0:
+                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
+                else:
+                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
+                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
+            else:
+                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+                target_replace_modules = (
+                    lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+                )
+        elif method == 'OFT':
+            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
+            target_replace_modules = (
+                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+            )
 
         for name, module in root_module.named_modules():
             if module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
-                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
-                        lora_name = prefix + "." + name + "." + child_name
-                        lora_name = lora_name.replace(".", "_")
-                        name_to_module[lora_name] = child_module
+                    if method == 'LoRA':
+                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+                            name_to_module[lora_name] = child_module
+                    elif method == 'OFT':
+                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "conv2d_1x1" or child_module.__class__.__name__ == "conv2d":
+                            oft_name = prefix + "." + name + "." + child_name
+                            oft_name = oft_name.replace(".", "_")
+                            name_to_module[oft_name] = child_module
 
     for model, ratio in zip(models, ratios):
         logger.info(f"loading: {model}")
         lora_sd, _ = load_state_dict(model, merge_dtype)
 
         logger.info(f"merging...")
-        for key in tqdm(lora_sd.keys()):
-            if "lora_down" in key:
-                up_key = key.replace("lora_down", "lora_up")
-                alpha_key = key[: key.index("lora_down")] + "alpha"
-
-                # find original module for this lora
-                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
-                if module_name not in name_to_module:
-                    logger.info(f"no module found for LoRA weight: {key}")
-                    continue
-                module = name_to_module[module_name]
-                # logger.info(f"apply {key} to {module}")
-
-                down_weight = lora_sd[key]
-                up_weight = lora_sd[up_key]
-
-                dim = down_weight.size()[0]
-                alpha = lora_sd.get(alpha_key, dim)
-                scale = alpha / dim
-
-                # W <- W + U * D
-                weight = module.weight
-                # logger.info(module_name, down_weight.size(), up_weight.size())
-                if len(weight.size()) == 2:
-                    # linear
-                    weight = weight + ratio * (up_weight @ down_weight) * scale
-                elif down_weight.size()[2:4] == (1, 1):
-                    # conv2d 1x1
-                    weight = (
-                        weight
-                        + ratio
-                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                        * scale
-                    )
-                else:
-                    # conv2d 3x3
-                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
-                    weight = weight + ratio * conved * scale
-
-                module.weight = torch.nn.Parameter(weight)
+        if method == 'LoRA':
+            for key in tqdm(lora_sd.keys()):
+                if "lora_down" in key:
+                    up_key = key.replace("lora_down", "lora_up")
+                    alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                    # find original module for this lora
+                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                    if module_name not in name_to_module:
+                        logger.info(f"no module found for LoRA weight: {key}")
+                        continue
+                    module = name_to_module[module_name]
+                    # logger.info(f"apply {key} to {module}")
+
+                    down_weight = lora_sd[key]
+                    up_weight = lora_sd[up_key]
+
+                    dim = down_weight.size()[0]
+                    alpha = lora_sd.get(alpha_key, dim)
+                    scale = alpha / dim
+
+                    # W <- W + U * D
+                    weight = module.weight
+                    # logger.info(module_name, down_weight.size(), up_weight.size())
+                    if len(weight.size()) == 2:
+                        # linear
+                        weight = weight + ratio * (up_weight @ down_weight) * scale
+                    elif down_weight.size()[2:4] == (1, 1):
+                        # conv2d 1x1
+                        weight = (
+                            weight
+                            + ratio
+                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                            * scale
+                        )
+                    else:
+                        # conv2d 3x3
+                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+                        weight = weight + ratio * conved * scale
+
+                    module.weight = torch.nn.Parameter(weight)
+
+
+        elif method == 'OFT':
+
+            multiplier = 1.0
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            for key in tqdm(lora_sd.keys()):
+                if "oft_blocks" in key:
+                    oft_blocks = lora_sd[key]
+                    dim = oft_blocks.shape[0]
+                    break
+            for key in tqdm(lora_sd.keys()):
+                if "alpha" in key:
+                    oft_blocks = lora_sd[key]
+                    alpha = oft_blocks.item()
+                    break
+
+            def merge_to(key):
+                if "alpha" in key:
+                    return
+
+                # find original module for this OFT
+                module_name = ".".join(key.split(".")[:-1])
+                if module_name not in name_to_module:
+                    return
+                module = name_to_module[module_name]
+                # logger.info(f"apply {key} to {module}")
+
+                oft_blocks = lora_sd[key]
+
+                if isinstance(module, torch.nn.Linear):
+                    out_dim = module.out_features
+                elif isinstance(module, torch.nn.Conv2d):
+                    out_dim = module.out_channels
+
+                num_blocks = dim
+                block_size = out_dim // dim
+                constraint = (0 if alpha is None else alpha) * out_dim
+                eye = torch.eye(block_size, device=oft_blocks.device)
+
+                block_Q = oft_blocks - oft_blocks.transpose(1, 2)
+                norm_Q = torch.norm(block_Q.flatten())
+                new_norm_Q = torch.clamp(norm_Q, max=constraint)
+                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+                I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
+                block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
+                block_R_weighted = multiplier * block_R + (1 - multiplier) * I
+                R = torch.block_diag(*block_R_weighted)
+
+                # get org weight
+                org_sd = module.state_dict()
+                org_weight = org_sd["weight"].to(device)
+
+                R = R.to(org_weight.device, dtype=org_weight.dtype)
+
+                if org_weight.dim() == 4:
+                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
+                else:
+                    weight = torch.einsum("oi, op -> pi", org_weight, R)
+
+                weight = weight.contiguous()  # Tensorを連続化。ThreadPoolExecutor実行したので必要。
+                module.weight = torch.nn.Parameter(weight)
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
 
 
 def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
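Note (illustration only, not part of the patch series): the OFT branch added in this patch builds, for each target module, a block-diagonal rotation R from the trained oft_blocks tensor via the Cayley transform and rotates the output dimension of the original weight, rather than adding a low-rank delta as LoRA does. The following minimal, self-contained sketch mirrors the math in merge_to() above; the shapes, alpha value, and file-free setup are hypothetical, chosen only to make the snippet runnable in plain PyTorch.

    import torch

    num_blocks, block_size = 4, 8                  # hypothetical; real values come from the OFT state dict
    out_dim, in_dim = num_blocks * block_size, 16
    multiplier, alpha = 1.0, 1e-3                  # merge_to() hard-codes multiplier = 1.0

    oft_blocks = torch.randn(num_blocks, block_size, block_size) * 0.01
    org_weight = torch.randn(out_dim, in_dim)      # stand-in for a Linear weight (out_dim, in_dim)

    # skew-symmetrize and clamp the norm, as in merge_to()
    block_Q = oft_blocks - oft_blocks.transpose(1, 2)
    norm_Q = torch.norm(block_Q.flatten())
    new_norm_Q = torch.clamp(norm_Q, max=alpha * out_dim)
    block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))

    # Cayley transform: R = (I + Q)(I - Q)^-1 is orthogonal when Q is skew-symmetric
    I = torch.eye(block_size).unsqueeze(0).repeat(num_blocks, 1, 1)
    block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
    R = torch.block_diag(*(multiplier * block_R + (1 - multiplier) * I))

    print(torch.allclose(R @ R.T, torch.eye(out_dim), atol=1e-5))   # True: the merge only rotates the weight
    merged = torch.einsum("oi, op -> pi", org_weight, R)            # rotate the output channels
    print(merged.shape)                                             # torch.Size([32, 16])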
From 4c85bc896bcbbe4a551e4066c0174e6c53d4a078 Mon Sep 17 00:00:00 2001
From: Maru-mee <151493593+Maru-mee@users.noreply.github.com>
Date: Tue, 10 Sep 2024 14:42:15 +0900
Subject: [PATCH 2/4] Fix typo

---
 networks/sdxl_merge_lora.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index 49cf0081b..51e9b6835 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -13,10 +13,8 @@
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
-
 import concurrent.futures
-
 def load_state_dict(file_name, dtype):
     if os.path.splitext(file_name)[1] == ".safetensors":
         sd = load_file(file_name)
@@ -57,11 +55,9 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
     text_encoder1.to(merge_dtype)
     unet.to(merge_dtype)
 
-    # 方式を判定。OFT or LoRA_module
+    # detect the method。OFT or LoRA_module
     method = detect_method_from_training_model(models, merge_dtype)
-
-
-    print(f"method:{method}")
+    logger.info(f"method:{method}")
 
     # create module map
     name_to_module = {}
@@ -93,7 +89,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
                             lora_name = lora_name.replace(".", "_")
                             name_to_module[lora_name] = child_module
                     elif method == 'OFT':
-                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "conv2d_1x1" or child_module.__class__.__name__ == "conv2d":
+                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
                             oft_name = prefix + "." + name + "." + child_name
                             oft_name = oft_name.replace(".", "_")
                             name_to_module[oft_name] = child_module
@@ -209,7 +205,7 @@ def merge_to(key):
                 else:
                     weight = torch.einsum("oi, op -> pi", org_weight, R)
 
-                weight = weight.contiguous()  # Tensorを連続化。ThreadPoolExecutor実行したので必要。
+                weight = weight.contiguous()  # Make Tensor contiguous; required due to ThreadPoolExecutor
                 module.weight = torch.nn.Parameter(weight)
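Note (illustration only, not part of the patch series): besides translating the comments, the substantive fix in this patch is the class-name check for OFT targets. name_to_module is keyed off child_module.__class__.__name__, and PyTorch layer classes are named exactly "Linear" and "Conv2d", so the earlier "conv2d_1x1"/"conv2d" comparisons could never match and convolution modules would have been silently skipped. A two-line check in plain PyTorch:

    import torch.nn as nn
    print(nn.Linear(8, 8).__class__.__name__, nn.Conv2d(8, 8, 1).__class__.__name__)  # -> Linear Conv2d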
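Note (illustration only, not part of the patch series): the comment style fixed here ("。" to ":") has no functional effect; the remaining functional cleanup follows in the next patch.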
From 3bf09b0266bb099e46176be5b2c7cf9e14ab56ac Mon Sep 17 00:00:00 2001
From: Maru-mee <151493593+Maru-mee@users.noreply.github.com>
Date: Tue, 10 Sep 2024 14:47:41 +0900
Subject: [PATCH 3/4] Fix typo_2

---
 networks/sdxl_merge_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index 51e9b6835..e75d4bdfd 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -55,7 +55,7 @@ def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
     text_encoder1.to(merge_dtype)
     unet.to(merge_dtype)
 
-    # detect the method。OFT or LoRA_module
+    # detect the method: OFT or LoRA_module
     method = detect_method_from_training_model(models, merge_dtype)
     logger.info(f"method:{method}")

From 8d32cbd0f2f84ebf8e6abb88c7def6e3dc4779a7 Mon Sep 17 00:00:00 2001
From: Maru-mee <151493593+Maru-mee@users.noreply.github.com>
Date: Thu, 12 Sep 2024 20:35:30 +0900
Subject: [PATCH 4/4] Delete unused parameter 'eye'

---
 networks/sdxl_merge_lora.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/networks/sdxl_merge_lora.py b/networks/sdxl_merge_lora.py
index e75d4bdfd..2c998c8cb 100644
--- a/networks/sdxl_merge_lora.py
+++ b/networks/sdxl_merge_lora.py
@@ -183,7 +183,6 @@ def merge_to(key):
                 num_blocks = dim
                 block_size = out_dim // dim
                 constraint = (0 if alpha is None else alpha) * out_dim
-                eye = torch.eye(block_size, device=oft_blocks.device)
 
                 block_Q = oft_blocks - oft_blocks.transpose(1, 2)
                 norm_Q = torch.norm(block_Q.flatten())
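Note (illustration only, not part of the patch series): with the series applied, merge_to_sd_model() picks the LoRA or OFT code path automatically from the state-dict keys ("lora_up"/"lora_down" vs. "oft_blocks"), so an OFT network should merge with the same kind of command line already used for LoRA models. The invocation below is a sketch: the file names are hypothetical and the flag names are assumed to match the existing script's argument parser, which this series does not change. Also note that the OFT branch hard-codes multiplier = 1.0, so the --ratios value does not scale an OFT merge the way it scales a LoRA merge.

    python networks/sdxl_merge_lora.py --sd_model sdxl_base.safetensors --save_to merged.safetensors --models my_oft.safetensors --ratios 1.0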