From c96696acd0d98d20959500f21c1c756e690e2648 Mon Sep 17 00:00:00 2001
From: Thomas Bauwens
Date: Thu, 19 Dec 2024 05:57:48 +0100
Subject: [PATCH] Fix hardcoded dtypes in DeBERTa model causing range mismatches.

---
 .../models/deberta/modeling_deberta.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index c9a85bcad1bd6f..6d4fedf7246d0e 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -62,7 +62,7 @@ def __init__(self, size, eps=1e-12):
 
     def forward(self, hidden_states):
         input_type = hidden_states.dtype
-        hidden_states = hidden_states.float()
+        hidden_states = hidden_states.float()  # TODO: Even when working in bfloat16?
         mean = hidden_states.mean(-1, keepdim=True)
         variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
         hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
@@ -134,7 +134,7 @@ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
 # Full credits to @Szustarol
 @torch.jit.script
 def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
-    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
+    return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=query_layer.dtype) * scale_factor)
 
 
 @torch.jit.script
@@ -184,8 +184,8 @@ def __init__(self, config):
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
         self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
-        self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
-        self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+        self.q_bias = nn.Parameter(torch.zeros((self.all_head_size)))
+        self.v_bias = nn.Parameter(torch.zeros((self.all_head_size)))
         self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
 
         self.relative_attention = getattr(config, "relative_attention", False)
@@ -271,8 +271,7 @@ def forward(
         rel_att: int = 0
         # Take the dot product between "query" and "key" to get the raw attention scores.
         scale_factor = 1 + len(self.pos_att_type)
-        scale = scaled_size_sqrt(query_layer, scale_factor)
-        query_layer = query_layer / scale.to(dtype=query_layer.dtype)
+        query_layer = query_layer / scaled_size_sqrt(query_layer, scale_factor)
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
 
         if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
@@ -287,7 +286,7 @@ def forward(
             attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
 
         attention_mask = attention_mask.bool()
-        attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
+        attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(attention_scores.dtype).min)
         # bsz x height x length x dimension
         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
 
@@ -1133,7 +1132,7 @@ def forward(
                     )
                     labels = torch.gather(labels, 0, label_index.view(-1))
                     loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).to(dtype=encoder_layer.dtype), labels.view(-1))
                 else:
                     loss = torch.tensor(0).to(logits)
             else:
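
For context (not part of the patch itself): the snippet below is a minimal sketch of the kind of range
mismatch that hardcoded torch.float dtypes can cause once the model runs in float16. The tensor and
variable names (scores, mask, bad, good) are stand-ins chosen for illustration, not code from the model:

    import torch

    # Minimal sketch (not from the patch): stand-ins for fp16 attention scores and an attention mask.
    scores = torch.zeros(1, 4, dtype=torch.float16)
    mask = torch.tensor([[True, True, False, False]])

    # Filling with the fp32 minimum (~-3.4e38) overflows the fp16 range and becomes -inf.
    bad = scores.masked_fill(~mask, torch.finfo(torch.float32).min)

    # Taking the minimum of the tensor's own dtype stays in range (-65504 for fp16).
    good = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

    print(bad)   # masked positions are -inf
    print(good)  # masked positions are -65504., the fp16 minimum

The edits above follow the same idea: rather than pinning tensors to torch.float, they take the dtype
from the tensors already involved in the computation.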