
Refactor LoRA #3778

Merged: 4 commits, Jul 9, 2023

Conversation

@williamberman (Contributor) commented on Jun 13, 2023

This refactors the text encoder LoRA code to support monkey-patching T5 as well as CLIP.

tl;dr

We mainly simplify the _modify_text_encoder method so that it creates the new LoRA layers itself instead of taking them as inputs. It also now returns a list of the new parameters, so we don't have to construct a separate object holding the parameters to pass to the optimizer in the training script.
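For context, here is a minimal sketch of that kind of monkey-patching (class and attribute names approximate the PR rather than quoting it; T5 names its attention projections q/k/v/o rather than the CLIP-style names matched below):

```python
import torch.nn as nn

class LoRALinearLayer(nn.Module):
    """Trainable low-rank residual: x -> up(down(x))."""
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)  # LoRA starts out as a no-op

    def forward(self, x):
        return self.up(self.down(x))

class PatchedLoraLinear(nn.Module):
    """Wraps a frozen nn.Linear and adds the LoRA branch to its output."""
    def __init__(self, regular_linear_layer, rank=4):
        super().__init__()
        self.regular_linear_layer = regular_linear_layer
        self.lora_linear_layer = LoRALinearLayer(
            regular_linear_layer.in_features, regular_linear_layer.out_features, rank
        )

    def forward(self, x):
        return self.regular_linear_layer(x) + self.lora_linear_layer(x)

def modify_text_encoder(text_encoder, rank=4):
    """Monkey-patch the attention projections and return the new parameters."""
    # Collect targets first so we don't mutate the module tree mid-traversal.
    targets = [
        (module, name)
        for module in text_encoder.modules()
        for name in ("q_proj", "k_proj", "v_proj", "out_proj")
        if isinstance(getattr(module, name, None), nn.Linear)
    ]
    lora_parameters = []
    for module, name in targets:
        patched = PatchedLoraLinear(getattr(module, name), rank)
        setattr(module, name, patched)
        lora_parameters.extend(patched.lora_linear_layer.parameters())
    return lora_parameters
```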

We also remove the use of the AttnProcsLayers class from the training script. We can instead pass the parameters directly to the optimizer and rely on the standard state-dict loading functions. This means we don't have to update its state-dict renaming code to be compatible with T5. A rough sketch of what this looks like in the training script is below.
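(Variable names here are illustrative, not the script's.)

```python
import itertools

import torch

# The parameter lists returned by the patching step go straight to the
# optimizer; no AttnProcsLayers wrapper is needed.
params_to_optimize = itertools.chain(unet_lora_parameters, text_encoder_lora_parameters)
optimizer = torch.optim.AdamW(params_to_optimize, lr=1e-4)
```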

We also update the checkpoint saving and loading code to call directly into the mixin without instantiating dummy pipelines. Most of the functions involved don't require access to instance state, so they can be converted to class methods.
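A sketch of the resulting classmethod-style call (the keyword arguments are assumptions based on the PR description, not a verbatim signature):

```python
from diffusers.loaders import LoraLoaderMixin

# Save directly via the mixin classmethod instead of building a dummy pipeline.
LoraLoaderMixin.save_lora_weights(
    save_directory="lora-checkpoint",
    unet_lora_layers=unet_lora_layers_to_save,
    text_encoder_lora_layers=text_encoder_lora_layers_to_save,
)
```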

@HuggingFaceDocBuilderDev commented on Jun 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@williamberman (Contributor, Author) commented:

@takuma104 it would be super helpful if you could take a look here :) I tried requesting you as a reviewer but it wouldn't let me.

Review thread on src/diffusers/loaders.py (outdated, resolved)
@sayakpaul (Member) left a comment:

@williamberman

I think this is a good start, and I can see how it removes some anti-patterns in the existing codebase and introduces flexibility, especially in how you are dealing with PatchedLoraLinear.

I left some comments. Let me know if anything is unclear.

@takuma104 (Contributor) commented:

@williamberman Looks great to me! This is an excellent refactoring job. I learned a lot from it. Just to be sure, I applied a Kohya LoRA and tried generating images, and there were absolutely no issues.

@patrickvonplaten (Contributor) left a comment:

Need to take a deeper look here later today.

@williamberman (Contributor, Author) commented:

> @williamberman Looks great to me! This is an excellent refactoring job. I learned a lot from it. Just to be sure, I applied a Kohya LoRA and tried generating images, and there were absolutely no issues.

Awesome, super helpful! Any other LoRA edge cases I should try?

@williamberman (Contributor, Author) commented:

I added some changes which remove the use of AttnProcsLayers within the LoRA DreamBooth script. It's not clear to me how to make it compatible with the T5 patching, and I think we can handle saving and loading a subset of the weights by just being more explicit about which keys of the state dict we're storing.

Comment on lines 896 to 893

```python
if isinstance(model, type(accelerator.unwrap_model(unet))):
    unet_lora_layers_to_save = model.attn_processors_state_dict
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
    raise ValueError(f"unexpected save model: {model.__class__}")
```
Contributor (Author):

Because we pass the whole model to accelerator.prepare, we now get the whole model in the hook. Additionally, I think this makes it clearer that we just map each model to a subset of its state dict.
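For context, a sketch of how this hook plugs into accelerate (abbreviated; the weights.pop() detail is an assumption about how the script keeps accelerate from also serializing the full model weights):

```python
def save_model_hook(models, weights, output_dir):
    unet_lora_layers_to_save = None
    text_encoder_lora_layers_to_save = None
    for model in models:
        if isinstance(model, type(accelerator.unwrap_model(unet))):
            unet_lora_layers_to_save = model.attn_processors_state_dict
        elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
            text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
        else:
            raise ValueError(f"unexpected save model: {model.__class__}")
        # Drop the full weights so only the LoRA subset gets written out.
        weights.pop()
    LoraLoaderMixin.save_lora_weights(
        output_dir,
        unet_lora_layers=unet_lora_layers_to_save,
        text_encoder_lora_layers=text_encoder_lora_layers_to_save,
    )

accelerator.register_save_state_pre_hook(save_model_hook)
```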

@patrickvonplaten changed the title from "refactor to support patching LoRA into T5" to "Refactor LoRA" on Jun 19, 2023
@patrickvonplaten (Contributor) commented:

Looks very nice to me! I would prefer to add this additional property to the unet though, as I think it's a one- or two-liner in Python using existing methods (no?).

Apart from this the PR looks very nice to me, thanks a lot for the refactoring! Can we maybe run one quick LoRA + text encoder training as well to make sure nothing broke?

@sayakpaul (Member) left a comment:

Looks amazing!

If someone wanted to add LoRA fine-tuning support for another text encoder, I think the sequence of modifications needed is quite clear from this PR.

I have run the tests too (inference and DreamBooth training with SD) and they look super!

🤟

```python
images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
```
Contributor:

Why these changes? We try to avoid running pipelines in autocast if possible.

Contributor (Author):

I think I copied and pasted from the regular DreamBooth training script, which runs the validation inference under autocast. Will remove.

Contributor (Author):

Actually, I might have added this because when we load the rest of the model in fp16, we keep the LoRA weights in fp32 and needed autocast for them to work together.

Contributor (Author):

Yes, looking through the commit history, that's why I added it. Is that OK?

Member:

For running inference validation at intervals, I think keeping autocast is okay, as it helps keep things simple.

Contributor:

Why didn't we need it before?

Contributor (Author):

OK, so I was wrong that we needed it because of the dtype difference in the LoRA weights; the LoRA layer casts manually, internally.

The actual issue was the dtype of the unet output being passed to the vae.

The difference is that in the main version of the script, the unet is not wrapped in amp, so the output of the unet during validation has the same dtype as the unet: fp16.

In this branch, the full unet is wrapped in amp, so even though the unet is loaded in fp16, its output is fp32, and there's an error when the fp32 latents are passed to the fp16 vae.
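A minimal standalone repro of that mismatch (hypothetical tensors, not the pipeline code):

```python
import torch

# fp16 "vae" weights, fp32 "unet output", as in the failing configuration.
vae_weight = torch.randn(4, 4, device="cuda", dtype=torch.float16)
latents = torch.randn(1, 4, device="cuda", dtype=torch.float32)

# latents @ vae_weight  # RuntimeError: expected scalar type Half but found Float

# autocast resolves it by casting the inputs of eligible ops to a common dtype.
with torch.cuda.amp.autocast():
    out = latents @ vae_weight  # runs in fp16 under autocast
```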

@williamberman (Contributor, Author) commented on Jul 6, 2023:

The alternative to wrapping this section in amp is to put a check, either in the pipeline or at the beginning of the vae, for the dtype of the initial latents and manually cast them if necessary (see the sketch below). I like to avoid manual casts like that in case the caller expects execution in the dtype of their input (which is a reasonable assumption imo). So I would prefer to leave the amp decorator in this case.
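A minimal sketch of that rejected alternative, assuming a standard Stable Diffusion pipeline layout:

```python
# Cast incoming latents to the VAE dtype before decoding, instead of relying
# on autocast at the call site.
if latents.dtype != self.vae.dtype:
    latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor).sample
```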

@patrickvonplaten (Contributor) left a comment:

Still some open questions here we need to double-check:

  • Does autocasting the whole unet increase memory usage? If it does not, the change is OK with me, but we need to test.
  • Why add torch.autocast to the inference calls of the pipeline?
  • Docstrings are missing.
  • We need to make sure _remove_monkey_patch is correctly deprecated.

@sayakpaul (Member) left a comment:

The PR looks pretty nice to me. I ran the inference tests and they look alright: https://colab.research.google.com/gist/sayakpaul/6c835c5c42bad3776697692f4691a53b/scratchpad.ipynb.

But it would be cool to do a memory-usage comparison (maybe for a small number of steps) to check whether passing entire models to accelerator.prepare() does anything unexpected.

@patrickvonplaten (Contributor) commented:

Sorry, still unclear on #3778 (comment). I want to make sure that the memory usage of training after this PR is <= the memory usage of training before this PR. Can we test this quickly?

We can run this code: https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora with different GPU memory budgets (6GB, 8GB, 10GB) to see when it errors out.

You can limit your GPU memory with https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html#torch-cuda-set-per-process-memory-fraction

Then we can do the same with the new code in this PR. I just want to be sure we don't pay a GPU memory cost here. A sketch of that check follows.
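(A rough sketch of the suggested check, assuming a single visible GPU; the 8 GB budget is just an example.)

```python
import torch

# Emulate a smaller card by capping this process's share of GPU memory,
# then run the LoRA DreamBooth training and watch for OOMs / peak usage.
total = torch.cuda.get_device_properties(0).total_memory
torch.cuda.set_per_process_memory_fraction(8 * 1024**3 / total, device=0)

# ... run the training loop here ...

peak_gib = torch.cuda.max_memory_allocated(0) / 1024**3
print(f"peak allocated: {peak_gib:.2f} GiB")
```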

Also this still needs to be added, no?

Commits:

  • instantiate the lora linear layer on the same device as the regular linear layer
  • get lora rank from state dict
  • tests
  • fmt
  • can create lora layer in float32 even when rest of model is float16
  • fix loading model hook
  • remove load_lora_weights_ and T5 dispatching
  • remove Unet#attn_processors_state_dict
  • docstrings
@patrickvonplaten merged commit c2a28c3 into huggingface:main on Jul 9, 2023
@okotaku mentioned this pull request on Jul 15, 2023
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* refactor to support patching LoRA into T5

instantiate the lora linear layer on the same device as the regular linear layer

get lora rank from state dict

tests

fmt

can create lora layer in float32 even when rest of model is float16

fix loading model hook

remove load_lora_weights_ and T5 dispatching

remove Unet#attn_processors_state_dict

docstrings

* text encoder monkeypatch class method

* fix test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request on Apr 26, 2024 (same commit message as above)