Refactor LoRA #3778
Conversation
Force-pushed from fde2bf8 to ff70b2c
@takuma104 it would be super helpful if you could take a look here :) I tried requesting you as a reviewer but it wouldn't let me.
I think this is a good start, and I can see how it removes some anti-patterns in the existing codebase and introduces flexibility, especially in how you are dealing with `PatchedLoraLinear`.
Left some comments. Let me know if anything is unclear.
@williamberman Looks great to me! This is an excellent refactoring job. I learned a lot from it. Just to be sure, I applied a Kohya LoRA and tried generating images, and there were absolutely no issues.
Need to take a deeper look here later today
Awesome, super helpful! Any other LoRA edge cases I should try?
Force-pushed from ff70b2c to 8b149ac
I added some changes which remove the use of `AttnProcsLayers` within the LoRA DreamBooth script. It's not clear to me how to make it compatible with the T5 patching, and I think we can probably handle saving and loading a subset of the weights by just being more explicit about which keys in the state dict we're storing.
```python
if isinstance(model, type(accelerator.unwrap_model(unet))):
    unet_lora_layers_to_save = model.attn_processors_state_dict
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
    text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
    raise ValueError(f"unexpected save model: {model.__class__}")
```
Because we pass the whole model to accelerator.prepare, we now get the whole model in the hook. Additionally, I think this makes it a bit clearer that we just map the model to a subset of its state dict.
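A minimal sketch of what that mapping can look like, assuming the LoRA weights are identified by a substring in their keys (the function name and the filter are illustrative, not the training script's actual helper):

```python
import torch

def lora_state_dict_subset(model: torch.nn.Module) -> dict:
    # Keep only the state-dict entries that belong to the LoRA layers.
    # The "lora" substring filter is an assumption for illustration; the real
    # helper in the training script may select keys differently.
    return {k: v for k, v in model.state_dict().items() if "lora" in k}
```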
Force-pushed from 8b149ac to 1828f82
Looks very nice to me! Would prefer to add this additional property to the unet though, as I think it's a one- or two-liner in Python to do it with existing methods (no?). Apart from this the PR looks very nice to me - thanks a lot for the refactoring! Can we maybe run one quick LoRA + text encoder training as well to make sure nothing broke?
Looks amazing!
If someone wanted to have support for another text encoder LoRA fine-tuning, I think the sequence of modifications needed is quite clear from this PR.
I have run the tests too (inference and DreamBooth training with SD) and they look super!
🤟
```python
]
images = []
for _ in range(args.num_validation_images):
    with torch.cuda.amp.autocast():
```
Why these changes? We try to avoid running pipelines in autocast if possible
I think I copy and pasted from the regular dreambooth training script which runs the validation inference under autocast. Will remove
Actually I might have added this because when we load the rest of the model in fp16, we keep the lora weights in fp32 and needed autocast for them to work together
Yes, looking through the commit history, that's why I added it. Is that ok?
For running inference validation in intervals, I think keeping autocast is okay as it helps keep things simple.
Why didn't we need it before?
Ok, so I was wrong that we needed it because of the dtype difference in the LoRA weights; the LoRA layer casts manually internally.
The issue was the dtype of the unet's output being passed to the vae.
The difference is that in the `main` version of the script, the unet is not wrapped in amp, so the output of the unet during validation is the same dtype as the unet, fp16.
In this branch, the full unet is wrapped in amp, so even though the unet is loaded in fp16 the output is fp32, and then there's an error when the fp32 latents are passed to the fp16 vae.
The alternative to wrapping this section in amp is to put a check, either in the pipeline or at the beginning of the vae, for the dtype of the initial latents and manually cast them if necessary. I like to avoid manual casts like that in case the caller expects execution in the dtype of their input (which is a reasonable assumption imo). So I would prefer to leave the amp decorator in this case.
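For reference, a minimal sketch of that manual-cast alternative, written as a standalone helper; the names and the assumption that `decode` returns an object with a `.sample` attribute (diffusers AutoencoderKL-style) are illustrative:

```python
import torch

def decode_latents(vae: torch.nn.Module, latents: torch.Tensor) -> torch.Tensor:
    # Manual-cast alternative (not adopted in this PR): compare the dtype of the
    # incoming latents with the vae's parameters and cast only when they differ.
    vae_dtype = next(vae.parameters()).dtype
    if latents.dtype != vae_dtype:
        latents = latents.to(dtype=vae_dtype)
    # Assumes a diffusers AutoencoderKL-style `decode` that returns an object with `.sample`.
    return vae.decode(latents).sample
```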
Still some open questions here we need to double check:
- Does autocasting the whole unet not increase memory usage? If it does not increase memory usage, ok to change for me, but we need to test.
- Why add `torch.autocast` to the inference calls of the pipeline?
- Docstrings are missing.
- We need to make sure `_remove_monkey_patch` is correctly deprecated.
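One possible shape for that deprecation, sketched with a plain FutureWarning; the message, the idea of keeping the old helper as a thin shim, and the delegate name are assumptions, not the project's actual deprecation policy:

```python
import warnings

def _remove_monkey_patch(self):
    # Hypothetical shim: warn callers, then fall through to whatever replaces the
    # old monkey-patch removal logic in the refactored code path.
    warnings.warn(
        "`_remove_monkey_patch` is deprecated and will be removed in a future release.",
        FutureWarning,
        stacklevel=2,
    )
    self._undo_lora_monkey_patch()  # illustrative placeholder, not a real diffusers API
```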
PR looks pretty nice to me. I ran the inference tests and they look alright: https://colab.research.google.com/gist/sayakpaul/6c835c5c42bad3776697692f4691a53b/scratchpad.ipynb.
But would be cool to do a memory usage comparison (maybe for a small number of steps) to check if passing entire models to `accelerate.prepare()` does anything unexpected.
Sorry, still unclear on #3778 (comment). I want to make sure that the memory usage of training before this PR is <= the memory usage of training after this PR. Can we test this quickly? We can run this code: https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora with different GPU memory budgets (6GB, 8GB, 10GB) to see when it errors out. You can limit your GPU memory with https://pytorch.org/docs/stable/generated/torch.cuda.set_per_process_memory_fraction.html#torch-cuda-set-per-process-memory-fraction. Then we can do the same with the new code in this PR. Just want to be sure we don't pay a price in GPU memory requirements here. Also, this still needs to be added, no?
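A minimal sketch of that setup, capping the current process to roughly a target budget before running a few training steps (the 8 GB target is just an example; repeat for 6 and 10 GB):

```python
import torch

# Cap this process at roughly `target_gb` of GPU memory so the pre- and post-refactor
# scripts can be compared at the same budget and we can see which one errors out first.
target_gb = 8  # example budget; repeat the run with 6 and 10 as well
total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
torch.cuda.set_per_process_memory_fraction(target_gb / total_gb, device=0)

# ...then launch a small number of LoRA DreamBooth training steps in this process...
```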
instantiate the lora linear layer on the same device as the regular linear layer
get lora rank from state dict
tests
fmt
can create lora layer in float32 even when rest of model is float16
fix loading model hook
remove load_lora_weights_ and T5 dispatching
remove Unet#attn_processors_state_dict
docstrings
Force-pushed from 5268ef3 to 99e35de
* refactor to support patching LoRA into T5
  instantiate the lora linear layer on the same device as the regular linear layer
  get lora rank from state dict
  tests
  fmt
  can create lora layer in float32 even when rest of model is float16
  fix loading model hook
  remove load_lora_weights_ and T5 dispatching
  remove Unet#attn_processors_state_dict
  docstrings
* text encoder monkeypatch class method
* fix test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This refactors the text encoder LoRA code to support monkey patching T5 as well as CLIP.

tl;dr

We mainly simplify the `_modify_text_encoder` method to create the new LoRA layers itself instead of taking them as inputs. It also now returns a list of the new parameters, so we don't have to construct a separate object with the parameters to be passed to the optimizer in the training script.

We also remove the use of the `AttnProcsLayers` class from the training script. We can instead pass the parameters directly to the optimizer and rely on the standard state dict loading functions. This means we don't have to update its state dict renaming code to be compatible with T5.

We also update the checkpoint saving and loading code to call directly into the mixin without instantiating dummy pipelines. Most of the functions called don't require any access to instance state, so they can simply be moved to class methods.
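A rough sketch of the training-script pattern this enables. The exact call signature of the patcher, the `text_encoder` and `unet_lora_parameters` variables, and the hyperparameters are assumptions for illustration, not the exact API:

```python
import itertools

import torch
from diffusers.loaders import LoraLoaderMixin

# Illustrative only: assume `text_encoder` and `unet_lora_parameters` were set up
# earlier in the script, and that the class-method patcher returns the list of
# newly created LoRA parameters as described above.
text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder)

# No AttnProcsLayers wrapper: the parameters go straight into the optimizer, and
# checkpointing relies on the standard state-dict helpers instead.
optimizer = torch.optim.AdamW(
    itertools.chain(unet_lora_parameters, text_lora_parameters),
    lr=1e-4,  # placeholder learning rate
)
```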