DeBERTa's DisentangledSelfAttention hardcodes float dtype, which causes bfloat16 overflow error #35332
System Info
transformers: 4.47.0
Python: 3.10.5
PyTorch: 2.5.1+cu124
GPU: NVIDIA GTX 980 Ti
Who can help?
@ArthurZucker
Reproduction
I'm training a `DebertaForMaskedLM` model within a broader experimental framework, but you can reproduce the bug with simple inference as follows: instantiate such a model with dtype `bfloat16` and send a batch through it. One of two errors is then thrown in `modeling_deberta.py`, both in `DisentangledSelfAttention.forward()` (and both can be traced back to the same issue):

RuntimeError: expected m1 and m2 to have the same dtype, but got: float != struct c10::BFloat16
RuntimeError: value cannot be converted to type at::BFloat16 without overflow
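A minimal reproduction along these lines triggers the error (the checkpoint name is just an example; any DeBERTa v1 checkpoint should behave the same way):

```python
import torch
from transformers import AutoTokenizer, DebertaForMaskedLM

# Example checkpoint; any DeBERTa (v1) model should do
model_name = "microsoft/deberta-base"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DebertaForMaskedLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    model(**inputs)  # raises one of the two RuntimeErrors above
```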
Here's where they come from: two fields in DeBERTa's `DisentangledSelfAttention` are constructed by explicitly declaring their `dtype` as `torch.float`:

transformers/src/transformers/models/deberta/modeling_deberta.py
Lines 187 to 188 in 9613933
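Those two lines look roughly like this (paraphrased; see the permalink above for the exact code):

```python
# In DisentangledSelfAttention.__init__ -- the dtype is hardcoded to torch.float
self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
```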
Then, in `forward()`, we create the two tensors `query_layer` and `key_layer`, which start out with the `dtype` of the hidden states, i.e. the `dtype` of the model, namely `bfloat16`:

transformers/src/transformers/models/deberta/modeling_deberta.py
Lines 258 to 259 in 9613933
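Approximately (again paraphrased from the linked code):

```python
# Both projections inherit the dtype of hidden_states, i.e. bfloat16
qp = self.in_proj(hidden_states)
query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
```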
But then, one of these tensors, `query_layer`, is modified by adding `self.q_bias` into it. The resulting tensor inherits the `torch.float` data type:

transformers/src/transformers/models/deberta/modeling_deberta.py
Line 268 in 9613933
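Roughly:

```python
# bfloat16 + float32 promotes to float32, so query_layer silently becomes torch.float
query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
```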
The first RuntimeError can occur on the following line, when `query_layer` (now `torch.float`) and `key_layer` (still `torch.bfloat16`) are multiplied. I've had this line crash on one machine and work on another, so perhaps this kind of mixed precision sometimes works.

transformers/src/transformers/models/deberta/modeling_deberta.py
Line 276 in 9613933
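In essence (paraphrased, omitting the scaling details):

```python
# Mixed dtypes in the matmul: float32 query_layer vs. bfloat16 key_layer
# -> RuntimeError: expected m1 and m2 to have the same dtype, but got: float != struct c10::BFloat16
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
```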
The second RuntimeError occurs even when mixed precision is supported. It happens on the following line:
transformers/src/transformers/models/deberta/modeling_deberta.py
Line 290 in 9613933
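Paraphrased, that line does something like:

```python
# attention_scores is bfloat16, but the fill value is the minimum of query_layer's
# dtype (float32), which is more negative than bfloat16 can represent
# -> RuntimeError: value cannot be converted to type at::BFloat16 without overflow
attention_scores = attention_scores.masked_fill(
    ~(attention_mask), torch.finfo(query_layer.dtype).min
)
```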
`attention_scores` is of type `bfloat16`. You then ask to fill it with the minimal value for the data type of `query_layer`, not the data type of `attention_scores`. Because `query_layer.dtype` is `torch.float`, that minimal value (-3.40282e+38) is more negative than the most negative `torch.bfloat16` (-3.38953e+38). Hence, the overflow.

Expected behavior
The `dtype` of `self.q_bias` and `self.v_bias` should be set like the rest of the modules/tensors in the model, rather than being hardcoded. That would keep everything in `bfloat16`.
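A minimal sketch of that fix (one possible approach, not necessarily how the maintainers would want to solve it):

```python
# Let the parameters follow the model's dtype (set via torch_dtype=... or model.to(...))
# instead of pinning them to torch.float
self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
```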