Skip to content

Commit

Permalink
docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
williamberman committed Jul 1, 2023
1 parent 651cc93 commit 5268ef3
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions src/diffusers/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,26 @@ class LoraLoaderMixin:
unet_name = UNET_NAME

def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
`self.unet`.
See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs:
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
self.load_lora_into_text_encoder(
Expand Down Expand Up @@ -990,6 +1010,20 @@ def lora_state_dict(

@classmethod
def load_lora_into_unet(cls, state_dict, network_alpha, unet):
"""
This will load the LoRA layers specified in `state_dict` into `unet`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alpha (`float`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
"""

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
Expand All @@ -1014,6 +1048,22 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet):

@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key shoult be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alpha (`float`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
"""

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
Expand Down

0 comments on commit 5268ef3

Please sign in to comment.