
Commit

v 0.0.5 (#42)
* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

* Fix lora inject, added weight self apply lora (#39)

* Develop (#34)

* Add parameter to control rank of decomposition (#28)

* ENH: allow controlling rank of approximation

* Training script accepts lora_rank

* feat : statefully monkeypatch different loras + example ipynb + readme

Co-authored-by: brian6091 <brian6091@gmail.com>

* release : version 0.0.4, now able to tune rank, now add loras dynamically

* readme : add brian6091's discussions

* fix: inject lora in to_out module list

* feat: added weight self apply lora

* chore: add import copy

* fix: readded r

Co-authored-by: Simo Ryu <35953539+cloneofsimo@users.noreply.github.com>
Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: SimoRyu <cloneofsimo@korea.ac.kr>

* Revert "Fix lora inject, added weight self apply lora (#39)" (#40)

This reverts commit fececf3.

* fix : rank bug in monkeypatch

* fix cli fix

* visualization of the effect of LR

Co-authored-by: brian6091 <brian6091@gmail.com>
Co-authored-by: Davide Paglieri <paglieridavide@gmail.com>
3 people authored Dec 15, 2022
1 parent 9b46b77 commit b64b1d4
Showing 7 changed files with 918 additions and 40 deletions.
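
The headline change in this release is control over the rank of the LoRA decomposition. As a rough illustration of what that rank means (a minimal sketch of the low-rank idea with illustrative names, not the module defined in lora_diffusion/lora.py below): a LoRA adapter leaves the original weight frozen and adds the product of two small matrices whose inner dimension is r, so r caps both the expressiveness and the parameter count of the update.

import torch
import torch.nn as nn

# Illustrative names and sizes; the point is the shapes: the update has inner dimension r.
in_features, out_features, r = 768, 320, 4

frozen = nn.Linear(in_features, out_features, bias=False)  # original layer, kept fixed
lora_down = nn.Linear(in_features, r, bias=False)          # weight shape (r, in_features)
lora_up = nn.Linear(r, out_features, bias=False)           # weight shape (out_features, r)
nn.init.zeros_(lora_up.weight)                             # adapter starts as a no-op

x = torch.randn(2, in_features)
y = frozen(x) + lora_up(lora_down(x))                      # same shape as frozen(x)

# Parameter cost of the adapter vs. the full weight: r * (in + out) vs. in * out.
print(lora_down.weight.numel() + lora_up.weight.numel(), frozen.weight.numel())

This is also why the rank check in lora.py compares r against min(in_features, out_features): beyond that point the factorization is no longer a genuine low-rank restriction.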
Binary file added contents/lora_diff_lrs.jpg
Binary file added contents/lora_diff_lrs_0.6.jpg
35 changes: 25 additions & 10 deletions lora_diffusion/cli_lora_add.py
@@ -26,34 +26,48 @@ def add(
     ] = "lpl",
     with_text_lora: bool = False,
 ):
     print("Lora Add, mode " + mode)
     if mode == "lpl":
         assert output_path.endswith(".pt"), "Only .pt files are supported"

-        for _path_1, _path_2 in (
-            [(path_1, path_2)] + [(_text_lora_path(path_1), _text_lora_path(path_2))]
+        for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
+            [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
             if with_text_lora
             else []
         ):
             print("Loading", _path_1, _path_2)
             out_list = []
+            if opt == "text_encoder":
+                if not os.path.exists(_path_1):
+                    print(f"No text encoder found in {_path_1}, skipping...")
+                    continue
+                if not os.path.exists(_path_2):
+                    print(f"No text encoder found in {_path_2}, skipping...")
+                    continue

             l1 = torch.load(_path_1)
             l2 = torch.load(_path_2)

             l1pairs = zip(l1[::2], l1[1::2])
             l2pairs = zip(l2[::2], l2[1::2])

             for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
                 # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
                 x1.data = alpha * x1.data + (1 - alpha) * x2.data
                 y1.data = alpha * y1.data + (1 - alpha) * y2.data

                 out_list.append(x1)
                 out_list.append(y1)

-            torch.save(out_list, output_path)
-            if with_text_lora:
-                torch.save(
-                    out_list,
-                    _text_lora_path(output_path),
-                )
+            if opt == "unet":
+                print("Saving merged UNET to", output_path)
+                torch.save(out_list, output_path)
+
+            elif opt == "text_encoder":
+                print("Saving merged text encoder to", _text_lora_path(output_path))
+                torch.save(
+                    out_list,
+                    _text_lora_path(output_path),
+                )

     elif mode == "upl":

@@ -96,6 +110,7 @@ def add(
         shutil.rmtree(_tmp_output)

     else:
+        print("Unknown mode", mode)
         raise ValueError(f"Unknown mode {mode}")

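For reference, the "lpl" branch above merges two LoRA checkpoints by linearly interpolating each up/down pair. A stripped-down sketch of that merge outside the CLI (the file names are hypothetical, and it assumes both .pt files store the same flat [up, down, up, down, ...] list that the loop above iterates over):

import torch

def merge_lora_lists(path_a, path_b, out_path, alpha=0.5):
    # Pairwise interpolation: alpha * A + (1 - alpha) * B for every up/down tensor.
    a, b = torch.load(path_a), torch.load(path_b)
    merged = []
    for (up_a, down_a), (up_b, down_b) in zip(
        zip(a[::2], a[1::2]), zip(b[::2], b[1::2])
    ):
        merged.append(alpha * up_a + (1 - alpha) * up_b)
        merged.append(alpha * down_a + (1 - alpha) * down_b)
    torch.save(merged, out_path)

# merge_lora_lists("lora_a.pt", "lora_b.pt", "lora_merged.pt", alpha=0.3)  # hypothetical paths

Both checkpoints need the same rank and the same set of patched modules for the pairwise sum to line up; the CLI above relies on that implicitly rather than checking it.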
10 changes: 6 additions & 4 deletions lora_diffusion/lora.py
@@ -13,9 +13,9 @@ class LoraInjectedLinear(nn.Module):
     def __init__(self, in_features, out_features, bias=False, r=4):
         super().__init__()

-        if r >= min(in_features, out_features):
+        if r > min(in_features, out_features):
             raise ValueError(
-                f"LoRA rank {r} must be less than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
             )

         self.linear = nn.Linear(in_features, out_features, bias)
@@ -138,7 +138,7 @@ def weight_apply_lora(


 def monkeypatch_lora(
-    model, loras, target_replace_module=["CrossAttention", "Attention"]
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
 ):
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
@@ -151,6 +151,7 @@ def monkeypatch_lora(
                         _child_module.in_features,
                         _child_module.out_features,
                         _child_module.bias is not None,
+                        r=r,
                     )
                     _tmp.linear.weight = weight

@@ -174,7 +175,7 @@ def monkeypatch_lora(


 def monkeypatch_replace_lora(
-    model, loras, target_replace_module=["CrossAttention", "Attention"]
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
 ):
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
@@ -187,6 +188,7 @@ def monkeypatch_replace_lora(
                         _child_module.linear.in_features,
                         _child_module.linear.out_features,
                         _child_module.linear.bias is not None,
+                        r=r,
                     )
                     _tmp.linear.weight = weight

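Threading r through monkeypatch_lora means a model can now be patched at whatever rank a checkpoint was trained with. A hedged usage sketch (the pipeline id, file name, and prompt are placeholders rather than anything from this commit; the r passed here must match the lora_rank used by the training script):

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion.lora import monkeypatch_lora

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Flat [up, down, ...] list saved by the trainer; the path is a placeholder.
loras = torch.load("lora_weight.pt")

# Replace the UNet's attention Linear layers with LoraInjectedLinear at rank 8.
monkeypatch_lora(pipe.unet, loras, r=8)

image = pipe("a photo of a corgi in a meadow", num_inference_steps=30).images[0]
image.save("lora_sample.jpg")

If the rank passed here does not match the checkpoint, the low-rank weights inside LoraInjectedLinear end up with the wrong shapes when they are loaded, which is presumably the sort of mismatch the "fix : rank bug in monkeypatch" entry above addresses.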
853 changes: 853 additions & 0 deletions scripts/lora_lr_effects.ipynb


58 changes: 33 additions & 25 deletions scripts/run_img2img.ipynb


2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
 setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.0.4",
+    version="0.0.5",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
