Skip to content

Commit

Permalink
fix sorting method for LoRA keys
Browse files Browse the repository at this point in the history
- support _out_0
- sort _in before _out
- avoid false positives by only considering suffixes
  • Loading branch information
catwell committed Jan 26, 2024
1 parent 7bb531e commit edd83ea
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/refiners/foundationals/latent_diffusion/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def add_loras(
loras = {key: loras[key] for key in sorted(loras.keys(), key=SDLoraManager.sort_keys)}

# if no key contains "unet" or "text", assume all keys are for the unet
if all(["unet" not in key and "text" not in key for key in loras.keys()]):
if all("unet" not in key and "text" not in key for key in loras.keys()):
loras = {f"unet_{key}": value for key, value in loras.items()}

self.add_loras_to_unet(loras)
Expand Down Expand Up @@ -141,13 +141,12 @@ def pad(input: str, /, padding_length: int = 2) -> str:

@staticmethod
def sort_keys(key: str, /) -> tuple[str, int]:
# out0 happens sometimes as an alias for out ; this dict might not be exhaustive
key_char_order = {"q": 1, "k": 2, "v": 3, "out": 4, "out0": 4}
# this dict might not be exhaustive
key_char_order = {"_q": 1, "_k": 2, "_v": 3, "_in": 3, "_out": 4, "_out0": 4, "_out_0": 4}

for i, s in enumerate(key.split("_")):
if s in key_char_order:
prefix = SDLoraManager.pad("_".join(key.split("_")[:i]))
return (prefix, key_char_order[s])
for s in key_char_order:
if key.endswith(s):
return (SDLoraManager.pad(key.removesuffix(s)), key_char_order[s])

return (SDLoraManager.pad(key), 5)

Expand Down

0 comments on commit edd83ea

Please sign in to comment.