change dtype of output after passing through lora_A #1172

Closed

huseyinatahaninan opened this issue Nov 23, 2023 · 8 comments

@huseyinatahaninan

System Info

peft 0.6.2

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I am not sure if this is really a bug, but my question is: after passing the input x through lora_A, should we cast the result to lora_B.weight.dtype, like we cast x to lora_A.weight.dtype in the first place?

I am talking about this line:

output = lora_B(lora_A(dropout(x)))

Instead of output = lora_B(lora_A(dropout(x))), I was wondering whether it should be output = lora_B(lora_A(dropout(x)).to(lora_B.weight.dtype)). Otherwise, for instance in mixed-precision training, x is cast to fp32 (lora_A.weight.dtype), but after passing through lora_A the result becomes bf16, and that is what lora_B receives as its input. So I was wondering whether we should cast it back to fp32.
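
To make the dtype flow concrete, here is a small standalone snippet (not PEFT code, just an illustration of PyTorch autocast behavior on a CUDA device) showing how a bf16 autocast region makes an fp32 linear layer, like lora_A, produce a bf16 output:

import torch

lin_fp32 = torch.nn.Linear(4, 4).cuda()   # fp32 weights, standing in for lora_A
x = torch.randn(2, 4, device="cuda")      # fp32 input, like x after x.to(lora_A.weight.dtype)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = lin_fp32(x)

print(y.dtype)  # torch.bfloat16 -- this is the dtype that reaches lora_B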

Thanks very much for your help in advance!

Expected behavior

n/a

@khalil-Hennara

khalil-Hennara commented Nov 30, 2023

I am facing the same thing: whatever the model dtype is, it changes to fp32. If you change the dtype of the model after creating it, that only changes the dtype of the LoRA parameters, not the base model.

import torch
from transformers import AutoModel, AutoTokenizer

model_id = "BAAI/bge-small-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id, torch_dtype=torch.bfloat16)

def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_kbit_training,
    )

    peft_config = LoraConfig(
        task_type=TaskType.FEATURE_EXTRACTION,
        inference_mode=False,
        r=128,
        lora_alpha=12,
        lora_dropout=0.03,
        target_modules=["query", "value"],
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

model, lora_config = create_peft_config(model)

print(model.dtype)  # the output is torch.float32

If we try to change the dtype of the model like this:

model.dtype = torch.bfloat16

the base model still has dtype float32.
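
As far as I can tell, prepare_model_for_kbit_training upcasts fp16/bf16 parameters to fp32, which would explain why model.dtype prints float32 above. If the goal is really to have everything in bf16, one workaround (just a sketch, not an official recommendation) is to cast the wrapped model explicitly after create_peft_config, since Module.to() recurses into every submodule, unlike assigning model.dtype:

import torch

# sketch: after create_peft_config(model), cast the whole wrapped model back to bf16
model = model.to(torch.bfloat16)

# every parameter (base model and LoRA) should now be bf16
assert all(p.dtype == torch.bfloat16 for p in model.parameters())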

@khalil-Hennara

I would like to work on this issue

@baoleai

baoleai commented Dec 8, 2023

There is a fix #1010

@huseyinatahaninan
Author

Thanks @baoleai for pointing this out! Unfortunately it does not resolve my issue: I still see result += lora_B(lora_A(dropout(x))) * scaling, and after x passes through lora_A it becomes bf16, which is then the input to lora_B. What I'd like is for the output of lora_A to be cast to lora_B.weight.dtype.

@BenjaminBossan
Member

@huseyinatahaninan For us to reproduce, do you have a small code snippet that shows this issue cropping up?

@huseyinatahaninan
Author

Thanks @BenjaminBossan, actually I am not 100% sure whether this really is an issue. Maybe it's just the way it's supposed to be. Consider this minimal example:

import torch
from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM

if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained("distilgpt2")
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="lora_only", task_type="CAUSAL_LM")
    model = get_peft_model(model, lora_cfg)
    model = model.cuda()

    # prepare dummy inputs (labels = input_ids so the model returns a loss)
    input_ids = torch.randint(0, 1000, (2, 4)).cuda()
    labels = input_ids.clone()

    # forward pass under bf16 autocast
    with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]

If you add a breakpoint at

result += lora_B(lora_A(dropout(x))) * scaling

you'll see that lora_A(dropout(x)) has dtype=torch.bfloat16, and that is the input to lora_B. I was wondering if we should also adjust the dtype after passing through lora_A, similar to what happens one line earlier with x = x.to(lora_A.weight.dtype).

Essentially something like the following:

y = lora_A(dropout(x))
y = y.to(lora_B.weight.dtype)
result += lora_B(y) * scaling
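
If it helps to see the dtypes without attaching a debugger, a forward pre-hook on the lora_B modules prints what they receive (a sketch; it assumes the default PEFT naming, where the injected lora_B linear layers have "lora_B" in their module names):

import torch

def report_input_dtype(module, args):
    # args[0] is the tensor entering lora_B; under bf16 autocast this prints torch.bfloat16
    print(f"{module.__class__.__name__} input dtype: {args[0].dtype}")

for name, module in model.named_modules():
    if "lora_B" in name and isinstance(module, torch.nn.Linear):
        module.register_forward_pre_hook(report_input_dtype)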

The reason I am asking has to do with differentially private training, and I can explain why I'd like the input to lora_B to be torch.float32 instead of torch.bfloat16 if you'd like, but that's a bit out of scope :) I have different ways to handle this so it's all good, but I just wanted to quickly double-check with you whether this behavior is intended, because I do see that x.dtype is modified for the lora_A layer, which makes me question why it's not modified for the lora_B layer.

@BenjaminBossan
Member

Thanks for explaining @huseyinatahaninan. I think that if we're inside an autocast context, it is not totally unexpected that the input dtype would be the one indicated in the context. Although I agree that this could be considered inconsistent when comparing lora_A and lora_B, I'm not sure if your use case is common enough to justify making the suggested change.

I have different ways to handle this so it's all good

Please let us know if that no longer works with PEFT, otherwise I'd say let's keep it as is.

differentially private training

Nice, happy to see more work being done on this with LoRA.

@huseyinatahaninan
Author

That sounds great, thanks very much @BenjaminBossan, appreciate the discussion and this awesome repo :)
