What's up with autocast in PEFT? #971

Closed
MFajcik opened this issue Sep 27, 2023 · 8 comments

MFajcik (Contributor) commented Sep 27, 2023

Hi all,

I noticed that the attention function of my GPT-2 (_attn in modeling_gpt2.py) receives float32 inputs despite autocasting to bf16 via the context manager. Everything works as expected when LoRA and PEFT are turned off.

Library versions:

  • peft: 0.4.0
  • transformers: 4.28.1
  • torch: 2.0.1+cu117

Am I doing something wrong?

Cheers,
Martin

Who can help?

@pac

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Minimal working example:

import torch
from peft import get_peft_model, LoraConfig
from transformers import AutoModel

if __name__ == "__main__":
    model = AutoModel.from_pretrained("gpt2").cuda().bfloat16()
    lora_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, bias="lora_only", task_type="lm")
    model = get_peft_model(model, lora_cfg)

    # prepare dummy inputs
    input_ids = torch.randint(0, 1000, (2, 10)).cuda()

    # forward pass under bf16 autocast
    with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
        outputs = model(input_ids)
        hidden_states = outputs[0]  # AutoModel("gpt2") has no LM head, so there is no loss/logits to unpack

Check the dtype when entering the _attn function:

def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    ...

>>> key.dtype  # with peft
torch.float32

>>> key.dtype  # without peft
torch.bfloat16
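One way to observe this without editing modeling_gpt2.py is to register forward hooks on the attention input projection (c_attn) of each GPT-2 block before running the forward pass of the MWE above. This is only a sketch: it assumes the PeftModel/GPT2Model layout from that MWE (get_base_model(), the .h block list, .attn.c_attn), which may differ across versions.

def report_dtypes(name):
    def hook(module, inputs, output):
        # c_attn returns a single tensor; the tuple check is just defensive
        out = output[0] if isinstance(output, tuple) else output
        print(f"{name}: input={inputs[0].dtype}, output={out.dtype}")
    return hook

base = model.get_base_model() if hasattr(model, "get_base_model") else model
for i, block in enumerate(base.h):
    block.attn.c_attn.register_forward_hook(report_dtypes(f"h.{i}.attn.c_attn"))

The output dtype printed for c_attn is what _attn should see for query/key/value, so it can be compared with and without PEFT applied.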

Expected behavior

Autocast works: with PEFT applied, the tensors entering _attn are bf16 under the bf16 autocast context, just as they are without PEFT.

MFajcik (Contributor, Author) commented Oct 2, 2023

So the problem is in the forward method of the Linear class in peft.tuners.lora:

    def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

            x = x.to(self.lora_A[self.active_adapter].weight.dtype)

            result += (
                self.lora_B[self.active_adapter](
                    self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
                )
                * self.scaling[self.active_adapter]
            )
        else:
            result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        # the output is cast back to the *input* dtype
        result = result.to(previous_dtype)

        return result
The result is assumed to have the same dtype as the input. However, in general this is not true.

For instance, in GPT-2 the inputs to this linear layer are layer-normed hidden states, which are kept in fp32 (the inputs to the layernorm are bf16 and its parameters are bf16, but its outputs are fp32!). Under autocast, the Conv1D linear operator applied to that fp32 operand with bf16 parameters returns bf16 (fp32 x bf16 -> bf16). However, the LoRA code above, which replaces the Conv1D operator in GPT-2, returns fp32 again (fp32 x bf16 -> fp32), so all subsequent operations are also computed in fp32 and everything gets slower!
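A quick standalone check of this dtype behaviour under autocast (my own sketch, not from the original report; assumes a CUDA device):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8, device="cuda", dtype=torch.float32)    # fp32 activations (like the layernorm output)
w = torch.randn(16, 8, device="cuda", dtype=torch.bfloat16)  # bf16 weights

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = F.linear(x, w)
    print(y.dtype)  # torch.bfloat16: autocast casts both operands to bf16

The LoRA forward above then does result.to(previous_dtype), i.e. it casts this bf16 result back to fp32.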

So the solution is to retain the data type of F.linear in this case. But won't this break anything in general (e.g., when the model is quantized)?
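For illustration, a minimal sketch of what retaining the F.linear output dtype could look like, reduced to the active, unmerged-adapter branch of the forward above (my sketch of the idea only, not the change that was eventually merged into PEFT):

    def forward(self, x: torch.Tensor):
        # base projection; under autocast this determines the "natural" output dtype
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        base_dtype = result.dtype

        # LoRA branch (active, unmerged adapter)
        x = x.to(self.lora_A[self.active_adapter].weight.dtype)
        result = result + (
            self.lora_B[self.active_adapter](
                self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
            )
            * self.scaling[self.active_adapter]
        )

        # cast to the dtype the base layer produced, not to the input dtype
        return result.to(base_dtype)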

BenjaminBossan (Member):

Interesting, thanks a lot for digging deeper. I'm not sure if we can make a general statement about the expected dtype of the output; would that require us to basically know the dtype that would have been produced without applying PEFT?

"So the solution is to retain the data type of F.linear in this case."

Do you mean the dtype of the weight?

"But won't this break anything in general (e.g., when the model is quantized)?"

Note that for quantized layers, we use different classes (see bnb.py and gptq.py), so changes here should not affect those classes.

MFajcik (Contributor, Author) commented Oct 3, 2023

"Do you mean the dtype of the weight?"

No, here I meant the output dtype of F.linear, i.e. the dtype of result right after the F.linear call. The dtype of the weight does not seem to guarantee anything.

"would that require us to basically know the dtype that would have been produced without applying PEFT?"

Yes, exactly. It is certainly possible here (F.linear is called anyway), but in general I don't know how to determine a function's output dtype in advance without calling the function itself.

BenjaminBossan (Member):

Would the issue be solved if, instead of converting to previous_dtype, we convert to the dtype of the result of F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)?

MFajcik (Contributor, Author) commented Oct 5, 2023

@BenjaminBossan, yes. That is what I meant by the first part of my previous response.

BenjaminBossan (Member):

Could you please check whether making that change solves your initial problem and report back? If it does, we can check whether this change works in general (not sure about that yet) and implement it.

github-actions (bot):

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions bot closed this as completed Nov 6, 2023
MFajcik (Contributor, Author) commented Dec 6, 2023

This is being resolved in PR #1010.
