From edd83ea9181c860c09a0d6cd5a4775c430dba7ec Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 26 Jan 2024 17:03:53 +0100 Subject: [PATCH] fix sorting method for LoRA keys - support _out_0 - sort _in before _out - avoid false positives by only considering suffixes --- src/refiners/foundationals/latent_diffusion/lora.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index eb2c4917b..f96720ff7 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -44,7 +44,7 @@ def add_loras( loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)} # if no key contains "unet" or "text", assume all keys are for the unet - if all(["unet" not in key and "text" not in key for key in loras.keys()]): + if all("unet" not in key and "text" not in key for key in loras.keys()): loras = {f"unet_{key}": value for key, value in loras.items()} self.add_loras_to_unet(loras) @@ -141,13 +141,12 @@ def pad(input: str, /, padding_length: int = 2) -> str: @staticmethod def sort_keys(key: str, /) -> tuple[str, int]: - # out0 happens sometimes as an alias for out ; this dict might not be exhaustive - key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4} + # this dict might not be exhaustive + key_char_order = {"_q": 1, "_k": 2, "_v": 3, "_in": 3, "_out": 4, "_out0": 4, "_out_0": 4} - for i, s in enumerate(key.split("_")): - if s in key_char_order: - prefix = SDLoraManager.pad("_".join(key.split("_")[:i])) - return (prefix, key_char_order[s]) + for s in key_char_order: + if key.endswith(s): + return (SDLoraManager.pad(key.removesuffix(s)), key_char_order[s]) return (SDLoraManager.pad(key), 5)