
Loading A LoRA into NF4 Quantized Flux Fill Pipeline Gives an Error #10612

Open · hamzaakyildiz opened this issue on Jan 20, 2025 · 1 comment
Labels: bug (Something isn't working)

@hamzaakyildiz commented:

Describe the bug

When I try to load a LoRA, such as alimama-creative/FLUX.1-Turbo-Alpha, into the NF4-quantized Flux Fill pipeline, it raises an error.

Reproduction

from diffusers import FluxFillPipeline, FluxTransformer2DModel
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
import torch

dtype = torch.bfloat16

# Quantize the transformer weights to 4-bit NF4 with bitsandbytes.
quant_config = DiffusersBitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=dtype,
)

pipeline = FluxFillPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Fill-dev",
    transformer=transformer,
    torch_dtype=dtype,
).to("cuda")

# Loading the LoRA into the quantized pipeline raises the error below.
pipeline.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha")

Logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[5], line 1
----> 1 pipeline.load_lora_weights("alimama-creative/FLUX.1-Turbo-Alpha", adapter_name=f"lora_")

File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:1550, in FluxLoraLoaderMixin.load_lora_weights(self, pretrained_model_name_or_path_or_dict, adapter_name, **kwargs)
   1543 transformer_norm_state_dict = {
   1544     k: state_dict.pop(k)
   1545     for k in list(state_dict.keys())
   1546     if "transformer." in k and any(norm_key in k for norm_key in self._control_lora_supported_norm_keys)
   1547 }
   1549 transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
-> 1550 has_param_with_expanded_shape = self._maybe_expand_transformer_param_shape_or_error_(
   1551     transformer, transformer_lora_state_dict, transformer_norm_state_dict
   1552 )
   1554 if has_param_with_expanded_shape:
   1555     logger.info(
   1556         "The LoRA weights contain parameters that have different shapes that expected by the transformer. "
   1557         "As a result, the state_dict of the transformer has been expanded to match the LoRA parameter shapes. "
   1558         "To get a comprehensive list of parameter names that were modified, enable debug logging."
   1559     )

File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/diffusers/loaders/lora_pipeline.py:2020, in FluxLoraLoaderMixin._maybe_expand_transformer_param_shape_or_error_(cls, transformer, lora_state_dict, norm_state_dict, prefix)
   2017 parent_module = transformer.get_submodule(parent_module_name)
   2019 with torch.device("meta"):
-> 2020     expanded_module = torch.nn.Linear(
   2021         in_features, out_features, bias=bias, dtype=module_weight.dtype
   2022     )
   2023 # Only weights are expanded and biases are not. This is because only the input dimensions
   2024 # are changed while the output dimensions remain the same. The shape of the weight tensor
   2025 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
   2026 # explains the reason why only weights are expanded.
   2027 new_weight = torch.zeros_like(
   2028     expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
   2029 )

File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/torch/nn/modules/linear.py:105, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
    103 self.in_features = in_features
    104 self.out_features = out_features
--> 105 self.weight = Parameter(
    106     torch.empty((out_features, in_features), **factory_kwargs)
    107 )
    108 if bias:
    109     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

File ~/.pyenv/versions/3.10.0/envs/jupyter/lib/python3.10/site-packages/torch/nn/parameter.py:46, in Parameter.__new__(cls, data, requires_grad)
     42     data = torch.empty(0)
     43 if type(data) is torch.Tensor or type(data) is Parameter:
     44     # For ease of BC maintenance, keep this path for standard Tensor.
     45     # Eventually (tm), we should change the behavior for standard Tensor to match.
---> 46     return torch.Tensor._make_subclass(cls, data, requires_grad)
     48 # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
     49 t = data.detach().requires_grad_(requires_grad)

RuntimeError: Only Tensors of floating point and complex dtype can require gradients
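
The failure point suggests a dtype issue rather than a shape issue: _maybe_expand_transformer_param_shape_or_error_ rebuilds an nn.Linear with dtype=module_weight.dtype, and under bitsandbytes NF4 quantization the packed weights of the quantized module are stored with an integer dtype, which PyTorch refuses to wrap in a Parameter. A minimal sketch of that suspected failure mode (the shapes and the uint8 assumption below are illustrative, not taken from the model):

import torch

# Suspected root cause (assumption): bitsandbytes 4-bit layers keep their packed
# weights as torch.uint8, so rebuilding the layer with dtype=module_weight.dtype
# asks PyTorch to create an integer-dtype Parameter, which is not allowed.
with torch.device("meta"):
    torch.nn.Linear(64, 3072, bias=True, dtype=torch.uint8)
# RuntimeError: Only Tensors of floating point and complex dtype can require gradients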

System Info


  • 🤗 Diffusers version: 0.32.2
  • Platform: Linux-6.8.0-1019-aws-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.0
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.27.1
  • Transformers version: 4.47.1
  • Accelerate version: 1.2.1
  • PEFT version: 0.14.0
  • Bitsandbytes version: 0.45.0
  • Safetensors version: 0.5.2
  • xFormers version: not installed

Who can help?

@sayakpaul

@hamzaakyildiz added the bug (Something isn't working) label on Jan 20, 2025
@hamzaakyildiz changed the title from "Loading A LoRA into NF4 Quantized Flux Fill model Gives an Error" to "Loading A LoRA into NF4 Quantized Flux Fill Pipeline Gives an Error" on Jan 20, 2025
@sayakpaul (Member) commented:

#10588
