Commit b6b9986
hotfix : patching
cloneofsimo committed Dec 17, 2022
1 parent 540894d commit b6b9986
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions lora_diffusion/lora.py
@@ -302,13 +302,16 @@ def load_learned_embed_in_clip(
 def patch_pipe(
     pipe,
     unet_path,
-    token,
-    alpha: float = 1.0,
+    token: str,
     r: int = 4,
     patch_unet=True,
     patch_text=False,
     patch_ti=False,
     idempotent_token=True,
 ):
+    assert (
+        len(token) > 0
+    ), "Token cannot be empty. Input a non-empty token like <s>."
+
     ti_path = _ti_lora_path(unet_path)
     text_path = _text_lora_path(unet_path)
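
The signature change drops the unused alpha argument, types token as a required string, and fails fast on empty input. A minimal standalone sketch of the new guard (the helper name below is hypothetical; patch_pipe inlines the assert directly):

    # Hypothetical standalone version of the guard patch_pipe now runs first.
    def _require_token(token: str) -> None:
        assert (
            len(token) > 0
        ), "Token cannot be empty. Input a non-empty token like <s>."

    _require_token("<s>")  # passes silently
    try:
        _require_token("")  # fails fast instead of patching with a bad token
    except AssertionError as err:
        print(err)
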
@@ -323,13 +326,16 @@ def patch_pipe(
     for _module in pipe.text_encoder.modules():
         if _module.__class__.__name__ == "LoraInjectedLinear":
             text_encoder_has_lora = True
+    if patch_unet:
+        print("LoRA : Patching Unet")
 
-    if not unet_has_lora:
-        monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
-    else:
-        monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
+        if not unet_has_lora:
+            monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
+        else:
+            monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
 
     if patch_text:
+        print("LoRA : Patching text encoder")
         if not text_encoder_has_lora:
             monkeypatch_lora(
                 pipe.text_encoder,
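
This hunk is the heart of the hotfix: the inject-or-replace branch previously ran at function scope, so the UNet was patched even when patch_unet=False. A condensed sketch of the corrected control flow, assuming monkeypatch_lora and monkeypatch_replace_lora are importable from lora_diffusion as the diff uses them (the wrapper function itself is hypothetical):

    import torch
    from lora_diffusion import monkeypatch_lora, monkeypatch_replace_lora

    def patch_unet_if_requested(pipe, unet_path, r=4, patch_unet=True):
        # Hypothetical condensation of the fixed logic above: nothing
        # touches the UNet unless the caller asked for it.
        if not patch_unet:
            return
        # Detection mirrors the loops in patch_pipe: any LoraInjectedLinear
        # module means the model has already been monkeypatched.
        unet_has_lora = any(
            m.__class__.__name__ == "LoraInjectedLinear"
            for m in pipe.unet.modules()
        )
        if not unet_has_lora:
            # First patch: inject LoRA layers, then load the trained weights.
            monkeypatch_lora(pipe.unet, torch.load(unet_path), r=r)
        else:
            # Already injected: replace the weights in place instead of
            # stacking a second injection.
            monkeypatch_replace_lora(pipe.unet, torch.load(unet_path), r=r)
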
@@ -346,6 +352,7 @@ def patch_pipe(
                 r=r,
             )
     if patch_ti:
+        print("LoRA : Patching token input")
         token = load_learned_embed_in_clip(
             ti_path,
             pipe.text_encoder,
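
Taken together, the patched entry point can be exercised roughly as follows. A minimal usage sketch: the model id and weight path are illustrative, and the companion text-encoder and textual-inversion files are resolved from unet_path by _text_lora_path and _ti_lora_path:

    import torch
    from diffusers import StableDiffusionPipeline
    from lora_diffusion import patch_pipe, tune_lora_scale

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # token is now mandatory and must be non-empty, e.g. the placeholder
    # learned during textual-inversion training.
    patch_pipe(
        pipe,
        "./lora_weight.pt",  # illustrative path to the trained UNet LoRA
        token="<s>",
        patch_unet=True,
        patch_text=True,  # expects the companion text-encoder weights alongside
        patch_ti=True,    # expects the companion textual-inversion embedding
    )

    tune_lora_scale(pipe.unet, 1.0)  # optional: adjust LoRA strength
    image = pipe("a photo of <s>", num_inference_steps=50).images[0]
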