[Lora] Seperate logic #5809
Conversation
@@ -57,25 +66,6 @@ def text_encoder_mlp_modules(text_encoder):
    return mlp_modules


def text_encoder_lora_state_dict(text_encoder):
This is only used in training scripts, so let's remove it from here. That will make it easier to remove it completely later, once the training scripts are refactored.
@@ -66,6 +67,35 @@
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."


def text_encoder_attn_modules(text_encoder):
We still need these for PEFT for now, so let's move them here; that way we can easily delete src/diffusers/models/lora.py later.
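For context, a minimal sketch of what such a helper typically looks like for CLIP text encoders (based on the transformers CLIP layout; the names here are illustrative and may not match the PR exactly):

```python
from transformers import CLIPTextModel, CLIPTextModelWithProjection


def text_encoder_attn_modules(text_encoder):
    # Collect (name, module) pairs for every self-attention block of a CLIP text encoder,
    # so LoRA/PEFT code can address them by their fully qualified module names.
    attn_modules = []
    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            name = f"text_model.encoder.layers.{i}.self_attn"
            attn_modules.append((name, layer.self_attn))
    else:
        raise ValueError(f"Do not know how to get attention modules for: {text_encoder.__class__.__name__}")
    return attn_modules
```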
@@ -41,7 +41,7 @@
import diffusers
Missing example refactor from this PR: #5331
The documentation is not available anymore as the PR was closed or merged.
All in for sweeping off 🧹
Thanks!
state_dict = {}

def text_encoder_attn_modules(text_encoder):
    from transformers import CLIPTextModel, CLIPTextModelWithProjection
Why is this needed? We can directly import it at the top.
So that we can remove text_encoder_attn_modules from src/diffusers.
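For illustration only, a sketch of the deferred-import pattern under discussion (not necessarily the exact code in the PR): keeping the import inside the function means the copy that no longer lives in src/diffusers does not pull in CLIP at module import time.

```python
def text_encoder_attn_modules(text_encoder):
    # Deferred import: only resolved when the helper is actually called,
    # so the surrounding module does not hard-require CLIP at import time.
    from transformers import CLIPTextModel, CLIPTextModelWithProjection

    if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        raise ValueError(f"Unsupported text encoder: {text_encoder.__class__.__name__}")
    ...
```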
@@ -1339,7 +1368,7 @@ def process_weights(adapter_names, weights):
        )
        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

-    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
+    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):  # noqa: F821
Why the noqa?
Some linters complain about PreTrainedModel not being present.
thanks!
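For reference, a common alternative to the noqa comment (a sketch, not necessarily what this PR uses; the class name below is hypothetical) is to make the name visible to static analysis only:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Only evaluated by type checkers, so there is no runtime transformers import
    # and linters no longer flag the quoted annotation as undefined (F821).
    from transformers import PreTrainedModel


class LoraLoaderSketch:  # hypothetical class, just to show the method signature
    def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
        ...
```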
Thanks, looks like a good step in the direction of separating the logic. Some minor comments, but no blockers from my side.
src/diffusers/loaders/__init__.py (Outdated)
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
    deprecate(
        "text_encoder_load_state_dict in `models`",
        "0.27.0",
-        "`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
+        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights with PEFT.",
Maybe add info on how the weights can be retrieved with PEFT (or a reference).
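As one possible addition to that message (hedged: this assumes the text encoder has been adapted with PEFT and that the peft package is installed), the LoRA weights can be pulled out with PEFT's state-dict utility:

```python
from peft import get_peft_model_state_dict

# Returns only the adapter (LoRA) parameters that PEFT injected into the model.
text_encoder_lora_state_dict = get_peft_model_state_dict(text_encoder)
```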
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
    attn_module = getattr(attn_module, n)
operator.attrgetter can be used to replace this pattern, but looping explicitly is also fine.
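For completeness, a sketch of the attrgetter variant the reviewer mentions (illustrative only; attn_processor_name and unet come from the surrounding loop in the quoted snippet):

```python
import operator

# Everything before the final dot is the parent module that owns the processor.
parent_name = ".".join(attn_processor_name.split(".")[:-1])

# operator.attrgetter("a.b.c")(unet) resolves nested attributes in a single call;
# fall back to the unet itself when there is no parent path.
attn_module = operator.attrgetter(parent_name)(unet) if parent_name else unet
```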
* [Lora] Seperate logic
* [Lora] Seperate logic
* [Lora] Seperate logic
* add comments to explain the code better
* add comments to explain the code better
What does this PR do?
This PR better separates the old LoRA logic from the new (PEFT) LoRA logic, to make it easier to remove the old LoRA logic later.