What's up with autocast in PEFT? #971
Comments
So the problem is in the LoRA Linear forward method:

```python
def forward(self, x: torch.Tensor):
    previous_dtype = x.dtype

    if self.active_adapter not in self.lora_A.keys():
        return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    if self.disable_adapters:
        if self.r[self.active_adapter] > 0 and self.merged:
            self.unmerge()
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    elif self.r[self.active_adapter] > 0 and not self.merged:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        x = x.to(self.lora_A[self.active_adapter].weight.dtype)
        result += (
            self.lora_B[self.active_adapter](
                self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
            )
            * self.scaling[self.active_adapter]
        )
    else:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

    result = result.to(previous_dtype)  # cast back to the input dtype -- the line under discussion
    return result
```

The result is assumed to have the same dtype as the input. However, in general this is not true. For instance, in GPT-2 the inputs to the linear layer are layer-normed hidden states, which are kept in fp32 (the inputs to the layer norm are bf16 and its parameters are bf16, but its outputs are fp32!). The Conv1D linear operator applied to that fp32 operand with bf16 parameters returns bf16 again (fp32 x bf16 -> bf16). The LoRA code above, which replaces the Conv1D operator in GPT-2, instead returns fp32 (fp32 x bf16 -> fp32), so the subsequent operations are also computed in fp32 and everything gets slower. So the solution is to retain the dtype returned by F.linear in this case. But won't this break anything in general (e.g., when the model is quantized)?
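For illustration, here is a minimal standalone sketch of the dtype behavior described above (toy shapes and CPU autocast to bf16 are assumptions for illustration, not taken from the issue):

```python
# Minimal sketch: under autocast, F.linear picks the autocast dtype (bf16 here)
# regardless of the fp32 input, so casting the result back to x.dtype hands
# fp32 to the next layer, which is what the LoRA forward above does.
import torch
import torch.nn.functional as F

x = torch.randn(2, 8)  # fp32, like a layer-norm output kept in fp32
w = torch.randn(4, 8)  # fp32 weights; autocast casts operands on the fly

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = F.linear(x, w)
    print(out.dtype)              # torch.bfloat16
    print(out.to(x.dtype).dtype)  # torch.float32
```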
Interesting, thanks a lot for digging deeper. I'm not sure if we can make a general statement about the expected dtype of the output. Would that require us to basically know the dtype that would have been produced without applying PEFT?
Do you mean the dtype of the weight?
Note that for quantized layers, we use different classes.
No, here I meant the output dtype of F.linear.
Yes, exactly. It is certainly possible here (F.linear is called anyway), but in general I don't know how to predetermine a function's output dtype without calling the function itself.
Would the issue be solved if, instead of converting to previous_dtype, we kept the dtype that F.linear returns?
@BenjaminBossan, yes. That is what I meant by the first comment in my previous response.
Could you please check whether making that change solves your initial problem and report back? If it does, we can check whether this change works in general (not sure about that yet) and implement it.
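To make the proposal concrete, here is a hedged sketch of what that change could look like in the forward method quoted above (only a sketch; not necessarily the change that was eventually merged):

```python
# Hypothetical sketch: keep the dtype that the base F.linear produced instead
# of casting the result back to the input dtype (previous_dtype).
def forward(self, x: torch.Tensor):
    if self.active_adapter not in self.lora_A.keys():
        return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    if self.disable_adapters:
        if self.r[self.active_adapter] > 0 and self.merged:
            self.unmerge()
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    elif self.r[self.active_adapter] > 0 and not self.merged:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
        base_dtype = result.dtype  # dtype chosen by the base layer (and autocast)
        x = x.to(self.lora_A[self.active_adapter].weight.dtype)
        lora_out = (
            self.lora_B[self.active_adapter](
                self.lora_A[self.active_adapter](self.lora_dropout[self.active_adapter](x))
            )
            * self.scaling[self.active_adapter]
        )
        result = result + lora_out.to(base_dtype)
    else:
        result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
    return result  # no cast back to previous_dtype
```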
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
This is being resolved in PR #1010. |
Hi all,
I have noticed that the attention function of my GPT-2 model (_attn in modeling_gpt2.py) receives float32 inputs despite autocasting to bf16 in the context manager. Everything works as expected when LoRA and PEFT are turned off.
Library versions:
Am I doing something wrong?
Cheers,
Martin
Who can help?
@pac
Information
Tasks
An officially supported task in the examples folder

Reproduction

Minimum working code:

Check the dtype when entering the _attn function (a hypothetical sketch of such a check is given below).

Expected behavior
autocast works
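Since the original minimum working code is not preserved in this thread, the following is a purely hypothetical reproduction sketch; the model name, target module, and hook placement are assumptions, not the reporter's script:

```python
# Hypothetical sketch: wrap GPT-2 with a LoRA adapter and inspect the dtype of
# the c_attn output (the q/k/v projection that feeds _attn) under bf16 autocast.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = LoraConfig(task_type="CAUSAL_LM", target_modules=["c_attn"], fan_in_fan_out=True)
model = get_peft_model(model, config)

# Hook the (LoRA-wrapped) c_attn layer of the first block and print its output dtype.
model.base_model.model.transformer.h[0].attn.c_attn.register_forward_hook(
    lambda module, inputs, output: print("c_attn output dtype:", output.dtype)
)

batch = tokenizer("Hello world", return_tensors="pt")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    model(**batch)
# With the casting behavior discussed above this should print torch.float32,
# whereas the plain Conv1D layer would return torch.bfloat16 under autocast.
```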