diff --git a/contents/lora_diff_lrs.jpg b/contents/lora_diff_lrs.jpg
new file mode 100644
index 0000000..1ab1b79
Binary files /dev/null and b/contents/lora_diff_lrs.jpg differ
diff --git a/contents/lora_diff_lrs_0.6.jpg b/contents/lora_diff_lrs_0.6.jpg
new file mode 100644
index 0000000..6d625a3
Binary files /dev/null and b/contents/lora_diff_lrs_0.6.jpg differ
diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py
index 240f930..3a416af 100644
--- a/lora_diffusion/cli_lora_add.py
+++ b/lora_diffusion/cli_lora_add.py
@@ -26,15 +26,23 @@ def add(
     ] = "lpl",
     with_text_lora: bool = False,
 ):
+    print("Lora Add, mode " + mode)
     if mode == "lpl":
-        assert output_path.endswith(".pt"), "Only .pt files are supported"
-
-        for _path_1, _path_2 in (
-            [(path_1, path_2)] + [(_text_lora_path(path_1), _text_lora_path(path_2))]
+        for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
+            [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
             if with_text_lora
             else []
         ):
+            print("Loading", _path_1, _path_2)
             out_list = []
+            if opt == "text_encoder":
+                if not os.path.exists(_path_1):
+                    print(f"No text encoder found in {_path_1}, skipping...")
+                    continue
+                if not os.path.exists(_path_2):
+                    print(f"No text encoder found in {_path_2}, skipping...")
+                    continue
+
             l1 = torch.load(_path_1)
             l2 = torch.load(_path_2)
 
@@ -42,18 +50,24 @@ def add(
             l2pairs = zip(l2[::2], l2[1::2])
 
             for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
+                # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
                 x1.data = alpha * x1.data + (1 - alpha) * x2.data
                 y1.data = alpha * y1.data + (1 - alpha) * y2.data
                 out_list.append(x1)
                 out_list.append(y1)
 
-        torch.save(out_list, output_path)
-        if with_text_lora:
-            torch.save(
-                out_list,
-                _text_lora_path(output_path),
-            )
+            if opt == "unet":
+
+                print("Saving merged UNET to", output_path)
+                torch.save(out_list, output_path)
+
+            elif opt == "text_encoder":
+                print("Saving merged text encoder to", _text_lora_path(output_path))
+                torch.save(
+                    out_list,
+                    _text_lora_path(output_path),
+                )
 
     elif mode == "upl":
 
@@ -96,6 +110,7 @@ def add(
         shutil.rmtree(_tmp_output)
 
     else:
+        print("Unknown mode", mode)
         raise ValueError(f"Unknown mode {mode}")
diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index d23b9b9..33deda4 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -13,9 +13,9 @@ class LoraInjectedLinear(nn.Module):
     def __init__(self, in_features, out_features, bias=False, r=4):
         super().__init__()
 
-        if r >= min(in_features, out_features):
+        if r > min(in_features, out_features):
             raise ValueError(
-                f"LoRA rank {r} must be less than {min(in_features, out_features)}"
+                f"LoRA rank {r} must be less than or equal to {min(in_features, out_features)}"
             )
 
         self.linear = nn.Linear(in_features, out_features, bias)
@@ -138,7 +138,7 @@ def weight_apply_lora(
 
 
 def monkeypatch_lora(
-    model, loras, target_replace_module=["CrossAttention", "Attention"]
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
 ):
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
@@ -151,6 +151,7 @@ def monkeypatch_lora(
                     _child_module.in_features,
                     _child_module.out_features,
                     _child_module.bias is not None,
+                    r=r,
                 )
 
                 _tmp.linear.weight = weight
@@ -174,7 +175,7 @@ def monkeypatch_lora(
 
 
 def monkeypatch_replace_lora(
-    model, loras, target_replace_module=["CrossAttention", "Attention"]
+    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
 ):
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
@@ -187,6 +188,7 @@ def monkeypatch_replace_lora(
                     _child_module.linear.in_features,
                     _child_module.linear.out_features,
                     _child_module.linear.bias is not None,
+                    r=r,
                 )
 
                 _tmp.linear.weight = weight
diff --git a/scripts/lora_lr_effects.ipynb b/scripts/lora_lr_effects.ipynb
new file mode 100644
index 0000000..518f7d4
--- /dev/null
+++ b/scripts/lora_lr_effects.ipynb
@@ -0,0 +1,853 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c18c3f7f367542a9ac893940cbd50f98",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Fetching 15 files: 0%| | 0/15 [00:00"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from diffusers import StableDiffusionPipeline, EulerAncestralDiscreteScheduler\n",
+    "import torch\n",
+    "\n",
+    "model_id = \"runwayml/stable-diffusion-v1-5\"\n",
+    "\n",
+    "pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(\n",
+    "    \"cuda:1\"\n",
+    ")\n",
+    "pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)\n",
+    "\n",
+    "prompt = \"female 3d game character bnha, Skill magic geek inside matrix deepdream radiating a glowing aura stuff loot legends stylized digital illustration video game icon artstation lois van baarle, ilya kuvshinov, rossdraws\"\n",
+    "prompt = \"portrait of female 3d game character bnha, impressionist style from the 19th century, claude monet, oil painting\"\n",
+    "\n",
+    "torch.manual_seed(0)\n",
+    "image = pipe(prompt, num_inference_steps=50, guidance_scale=4.5).images[0]\n",
+    "\n",
+    "from lora_diffusion import monkeypatch_replace_lora, monkeypatch_lora, tune_lora_scale\n",
+    "\n",
+    "monkeypatch_lora(pipe.unet, torch.load(\"../output_example_text_1e-4/lora_weight.pt\"))\n",
+    "monkeypatch_lora(pipe.text_encoder, torch.load(\"../output_example_text_1e-4/lora_weight.text_encoder.pt\"))\n",
+    "\n",
+    "tune_lora_scale(pipe.unet, alpha=0.0)\n",
+    "tune_lora_scale(pipe.text_encoder, alpha=0.0)\n",
+    "image # nice, but that's the base model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9d2ae94e54854a2aa43cc095fb6f7b76",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       " 0%| | 0/50 [00:00"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from lora_diffusion import monkeypatch_replace_lora, monkeypatch_lora, tune_lora_scale\n",
+    "\n",
+    "tune_lora_scale(pipe.unet, 0.8)\n",
+    "tune_lora_scale(pipe.text_encoder, 0.8)\n",
+    "torch.manual_seed(0)\n",
+    "pipe(prompt, num_inference_steps=50, guidance_scale=4.5).images[0]\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e2b9e67cafbb46caa632553b27f6904a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       " 0%| | 0/50 [00:00"
+       ""
       ]
      },
-     "execution_count": 4,
+     "execution_count": 1,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -69,25 +77,25 @@
     "\n",
     "prompt = \"style of sks, robotic horse with rocket launcher\"\n",
     "torch.manual_seed(1)\n",
-    "image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
+    "image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
     "\n",
     "image\n"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "3adc01e02d1c4ac2baeda43dad883e48",
+       "model_id": "6cbeaa8a136b44978d1809ea19028c90",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
        " 0%| | 0/38 [00:00"
+       ""
       ]
      },
-     "execution_count": 5,
+     "execution_count": 4,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -113,7 +121,7 @@
     "tune_lora_scale(pipe.unet, 1.00)\n",
     "\n",
     "torch.manual_seed(1)\n",
-    "image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
+    "image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
     "\n",
     "image\n"
    ]
@@ -127,18 +135,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "28e6a7adc2a442b99d4c2c203c146f0c",
+       "model_id": "503505131f9d4a839f2defd296489a56",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
        " 0%| | 0/38 [00:00"
+       ""
       ]
      },
-     "execution_count": 7,
+     "execution_count": 6,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -160,25 +168,25 @@
     "tune_lora_scale(pipe.unet, 1.5)\n",
     "\n",
     "torch.manual_seed(1)\n",
-    "image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
+    "image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
     "\n",
     "image"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "5c30582799304315af94ecc47caab4e2",
+       "model_id": "e805ec0652c8490db615dce85737af2b",
        "version_major": 2,
        "version_minor": 0
       },
       "text/plain": [
        " 0%| | 0/38 [00:00"
+       ""
       ]
      },
-     "execution_count": 8,
+     "execution_count": 7,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -200,7 +208,7 @@
     "tune_lora_scale(pipe.unet, 0.5)\n",
     "\n",
     "torch.manual_seed(1)\n",
-    "image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
+    "image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]\n",
     "\n",
     "image"
    ]
   }
diff --git a/setup.py b/setup.py
index e4c36b8..5544615 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ setup(
     name="lora_diffusion",
     py_modules=["lora_diffusion"],
-    version="0.0.4",
+    version="0.0.5",
     description="Low Rank Adaptation for Diffusion Models. Works with Stable Diffusion out-of-the-box.",
     author="Simo Ryu",
     packages=find_packages(),
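
Note on the `lpl` merge in `cli_lora_add.py` above: a saved LoRA `.pt` file is a flat list of alternating up/down weight tensors, so `l1[::2]` / `l1[1::2]` walk the up and down factors, and the merge is an element-wise linear interpolation of each factor. A minimal standalone sketch of that logic (the `.pt` paths here are placeholders, not files from this repo):

```python
import torch

alpha = 0.5  # blending weight of the first LoRA, as in add(..., alpha=0.5)

# A saved LoRA checkpoint is a flat list [up_0, down_0, up_1, down_1, ...].
l1 = torch.load("lora_a.pt")  # placeholder path
l2 = torch.load("lora_b.pt")  # placeholder path

merged = []
for (u1, d1), (u2, d2) in zip(zip(l1[::2], l1[1::2]), zip(l2[::2], l2[1::2])):
    merged.append(alpha * u1 + (1 - alpha) * u2)  # interpolate the "up" factors
    merged.append(alpha * d1 + (1 - alpha) * d2)  # interpolate the "down" factors

torch.save(merged, "lora_merged.pt")
```

Worth noting: interpolating the factors is a heuristic rather than an exact average of the two adaptations, since the blended product `(a*U1 + (1-a)*U2) @ (a*D1 + (1-a)*D2)` picks up cross terms and is not equal to `a*(U1 @ D1) + (1-a)*(U2 @ D2)`.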
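The new `r` argument on `monkeypatch_lora` / `monkeypatch_replace_lora` lets a LoRA trained at a non-default rank be patched in; `r` must match the rank used at training time, since it fixes the shapes of the injected up/down projections. A hedged usage sketch (the rank-8 checkpoint path is hypothetical):

```python
import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion import monkeypatch_lora, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical checkpoint trained with rank 8; pass the matching r.
monkeypatch_lora(pipe.unet, torch.load("lora_weight_r8.pt"), r=8)
tune_lora_scale(pipe.unet, 1.0)  # full LoRA strength
```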