diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md index 907f93d573a0..615af55ef5b5 100644 --- a/docs/source/en/tutorials/using_peft_for_inference.md +++ b/docs/source/en/tutorials/using_peft_for_inference.md @@ -75,6 +75,12 @@ image ![pixel-art](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_12_1.png) + + +By default, if the most up-to-date versions of PEFT and Transformers are detected, `low_cpu_mem_usage` is set to `True` to speed up the loading time of LoRA checkpoints. + + + ## Merge adapters You can also merge different adapter checkpoints for inference to blend their styles together. diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 8c8f2dfa84f8..2037bd787433 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -25,8 +25,11 @@ deprecate, get_adapter_name, get_peft_kwargs, + is_peft_available, is_peft_version, + is_torch_version, is_transformers_available, + is_transformers_version, logging, scale_lora_layers, ) @@ -39,6 +42,17 @@ ) +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + + if is_transformers_available(): from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules @@ -83,15 +97,24 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." 
+ ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -109,6 +132,7 @@ def load_lora_weights( unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) self.load_lora_into_text_encoder( state_dict, @@ -119,6 +143,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -237,7 +262,9 @@ def lora_state_dict( return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -255,10 +282,16 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. @@ -268,7 +301,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -281,6 +318,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -303,10 +341,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. 
+ # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -377,6 +430,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -547,12 +601,19 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: kwargs (`dict`, *optional*): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL # pipeline. @@ -573,7 +634,12 @@ def load_lora_weights( raise ValueError("Invalid LoRA checkpoint.") self.load_lora_into_unet( - state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self + state_dict, + network_alphas=network_alphas, + unet=self.unet, + adapter_name=adapter_name, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: @@ -585,6 +651,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} @@ -597,6 +664,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -717,7 +785,9 @@ def lora_state_dict( @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet - def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None): + def load_lora_into_unet( + cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `unet`. @@ -735,10 +805,16 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading only loading the pretrained LoRA weights and not initializing the random weights. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as # their prefixes. @@ -748,7 +824,11 @@ def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None # Load the layers corresponding to UNet. logger.info(f"Loading {cls.unet_name}.") unet.load_attn_procs( - state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline + state_dict, + network_alphas=network_alphas, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod @@ -762,6 +842,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -784,10 +865,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -858,6 +954,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -1126,15 +1223,22 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. 
""" if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -1151,6 +1255,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1163,6 +1268,7 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} @@ -1175,10 +1281,13 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1192,7 +1301,13 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -1236,8 +1351,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys @@ -1266,6 +1385,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1288,10 +1408,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. 
If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1362,6 +1497,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -1667,10 +1803,17 @@ def load_lora_weights( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -1690,6 +1833,7 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} @@ -1702,10 +1846,13 @@ def load_lora_weights( lora_scale=self.lora_scale, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod - def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -1723,7 +1870,13 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -1772,8 +1925,12 @@ def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, ada # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys @@ -1802,6 +1959,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1824,10 +1982,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1898,6 +2071,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -2132,6 +2306,7 @@ def load_lora_into_text_encoder( lora_scale=1.0, adapter_name=None, _pipeline=None, + low_cpu_mem_usage=False, ): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -2154,10 +2329,25 @@ def load_lora_into_text_encoder( adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + from peft import LoraConfig # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -2228,6 +2418,7 @@ def load_lora_into_text_encoder( adapter_name=adapter_name, adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config, + **peft_kwargs, ) # scale LoRA layers with `lora_scale` @@ -2416,15 +2607,22 @@ def load_lora_weights( Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + # if a dict is passed, copy it instead of modifying it inplace if isinstance(pretrained_model_name_or_path_or_dict, dict): pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() @@ -2441,11 +2639,14 @@ def load_lora_weights( transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, adapter_name=adapter_name, _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, ) @classmethod # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer - def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None): + def load_lora_into_transformer( + cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False + ): """ This will load the LoRA layers specified in `state_dict` into `transformer`. @@ -2459,7 +2660,13 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. 
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.: """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict keys = list(state_dict.keys()) @@ -2503,8 +2710,12 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 32ace77b6224..eaac52df6202 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -115,6 +115,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict `default_{i}` where i is the total number of adapters being loaded. weight_name (`str`, *optional*, defaults to None): Name of the serialized state dict file. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. Example: @@ -142,8 +145,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict adapter_name = kwargs.pop("adapter_name", None) _pipeline = kwargs.pop("_pipeline", None) network_alphas = kwargs.pop("network_alphas", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) allow_pickle = False + if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if use_safetensors is None: use_safetensors = True allow_pickle = True @@ -209,6 +218,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, ) else: raise ValueError( @@ -268,7 +278,9 @@ def _process_custom_diffusion(self, state_dict): return attn_processors - def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline): + def _process_lora( + self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage + ): # This method does the following things: # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy # format. For legacy format no filtering is applied. 
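For reference, a minimal usage sketch of the new keyword from the caller's side, assuming a CUDA device plus the `peft>=0.13.1` and `transformers>4.45.2` requirements enforced by the checks above; the SDXL base model and the pixel-art LoRA are the checkpoints already used on the tutorial page touched at the top of this diff, and the prompt is arbitrary:

```py
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# With recent peft/transformers installed, _LOW_CPU_MEM_USAGE_DEFAULT_LORA resolves
# to True, so passing the flag explicitly is optional; it is shown here for clarity.
pipe.load_lora_weights(
    "nerijs/pixel-art-xl",
    weight_name="pixel-art-xl.safetensors",
    adapter_name="pixel",
    low_cpu_mem_usage=True,
)

image = pipe("a corgi astronaut, pixel art", num_inference_steps=30).images[0]
```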
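The tests added below in `tests/lora/utils.py` assert that adapters injected with `low_cpu_mem_usage=True` stay on the `meta` device until a state dict is assigned, which is what makes the fast path cheap. A compact sketch of that flow on a toy module, assuming `peft>=0.13.1`; the module, layer names, and shapes are made up for illustration:

```py
import torch
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.utils import get_peft_model_state_dict


class ToyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(16, 32)
        self.fc2 = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.fc2(torch.nn.functional.relu(self.fc1(x)))


model = ToyMLP()
config = LoraConfig(r=4, target_modules=["fc1", "fc2"])

# Inject the LoRA layers without materializing their weights: they land on "meta".
inject_adapter_in_model(config, model, adapter_name="default", low_cpu_mem_usage=True)
assert "meta" in {p.device.type for p in model.parameters()}

# Assigning a state dict (randomly initialized here, loaded from disk in the real
# loaders) materializes the LoRA parameters, mirroring what the code paths changed
# above do via set_peft_model_state_dict(..., low_cpu_mem_usage=True).
dummy = {
    k: torch.randn(v.shape, dtype=v.dtype)
    for k, v in get_peft_model_state_dict(model).items()
}
set_peft_model_state_dict(model, dummy, adapter_name="default", low_cpu_mem_usage=True)
assert "meta" not in {p.device.type for p in model.parameters()}
```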
@@ -335,9 +347,12 @@ def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks # otherwise loading LoRA weights will lead to an error is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline) + peft_kwargs = {} + if is_peft_version(">=", "0.13.1"): + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - inject_adapter_in_model(lora_config, self, adapter_name=adapter_name) - incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name) + inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs) + incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs) if incompatible_keys is not None: # check only for unexpected keys diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7dc3f414d55c..a2f283d0c4f5 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -388,6 +388,24 @@ def decorator(test_case): return decorator +def require_transformers_version_greater(transformers_version): + """ + Decorator marking a test that requires transformers with a specific version, this would require some specific + versions of PEFT and transformers. + """ + + def decorator(test_case): + correct_transformers_version = is_transformers_available() and version.parse( + version.parse(importlib.metadata.version("transformers")).base_version + ) > version.parse(transformers_version) + return unittest.skipUnless( + correct_transformers_version, + f"test requires transformers with the version greater than {transformers_version}", + )(test_case) + + return decorator + + def require_accelerate_version_greater(accelerate_version): def decorator(test_case): correct_accelerate_version = is_peft_available() and version.parse( diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 5def867324f4..9c982e8de37f 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -32,13 +32,14 @@ floats_tensor, require_peft_backend, require_peft_version_greater, + require_transformers_version_greater, skip_mps, torch_device, ) if is_peft_available(): - from peft import LoraConfig + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import get_peft_model_state_dict @@ -65,6 +66,12 @@ def check_if_lora_correctly_set(model) -> bool: return False +def initialize_dummy_state_dict(state_dict): + if not all(v.device.type == "meta" for _, v in state_dict.items()): + raise ValueError("`state_dict` has non-meta values.") + return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()} + + @require_peft_backend class PeftLoraLoaderMixinTests: pipeline_class = None @@ -272,6 +279,136 @@ def test_simple_inference_with_text_lora(self): not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" ) + @require_peft_version_greater("0.13.1") + def test_low_cpu_mem_usage_with_injection(self): + """Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + if "text_encoder" in 
self.pipeline_class._lora_loadable_modules: + inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." + ) + self.assertTrue( + "meta" in {p.device.type for p in pipe.text_encoder.parameters()}, + "The LoRA params should be on 'meta' device.", + ) + + te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) + set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, + "No param should be on 'meta' device.", + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + self.assertTrue( + "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device." + ) + + denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser)) + set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device." + ) + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: + inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + self.assertTrue( + "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "The LoRA params should be on 'meta' device.", + ) + + te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2)) + set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True) + self.assertTrue( + "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()}, + "No param should be on 'meta' device.", + ) + + _, _, inputs = self.get_dummy_inputs() + output_lora = pipe(**inputs)[0] + self.assertTrue(output_lora.shape == self.output_shape) + + @require_peft_version_greater("0.13.1") + @require_transformers_version_greater("4.45.1") + def test_low_cpu_mem_usage_with_loading(self): + """Tests if we can load LoRA state dict with low_cpu_mem_usage.""" + + for scheduler_cls in self.scheduler_classes: + components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(output_no_lora.shape == self.output_shape) + + if "text_encoder" in self.pipeline_class._lora_loadable_modules: + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config) + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + if self.has_two_text_encoders or self.has_three_text_encoders: + if "text_encoder_2" in 
self.pipeline_class._lora_loadable_modules: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] + + with tempfile.TemporaryDirectory() as tmpdirname: + modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) + lora_state_dicts = self._get_lora_state_dicts(modules_to_save) + self.pipeline_class.save_lora_weights( + save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts + ) + + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) + + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), + "Loading from saved checkpoints should give same results.", + ) + + # Now, check for `low_cpu_mem_usage.` + pipe.unload_lora_weights() + pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) + + for module_name, module in modules_to_save.items(): + self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") + + images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue( + np.allclose( + images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3 + ), + "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.", + ) + def test_simple_inference_with_text_lora_and_scale(self): """ Tests a simple inference with lora attached on the text encoder + scale argument