From 5268ef3222d9fd139bc0893e41cec1a883098091 Mon Sep 17 00:00:00 2001 From: William Berman Date: Fri, 30 Jun 2023 17:31:28 -0700 Subject: [PATCH] docstrings --- src/diffusers/loaders.py | 50 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 12450322c076c..ee25a209d24c4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -838,6 +838,26 @@ class LoraLoaderMixin: unet_name = UNET_NAME def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into self.unet and self.text_encoder. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + + See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into + `self.unet`. + + See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded + into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + + kwargs: + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + """ state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) self.load_lora_into_text_encoder( @@ -990,6 +1010,20 @@ def lora_state_dict( @classmethod def load_lora_into_unet(cls, state_dict, network_alpha, unet): + """ + This will load the LoRA layers specified in `state_dict` into `unet` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters.
The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish them from text + encoder lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + """ + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1014,6 +1048,22 @@ def load_lora_into_unet(cls, state_dict, network_alpha, unet): @classmethod def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys should be prefixed with an + additional `text_encoder` to distinguish them from unet lora layers. + network_alpha (`float`): + See `LoRALinearLayer` for more details. + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added to the output of the regular + lora layer. + """ + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes.