From 609a1767e8ba367350abf3c553d40b68607987e5 Mon Sep 17 00:00:00 2001
From: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Date: Thu, 15 Feb 2024 00:55:48 +0100
Subject: [PATCH] [`Cleanup`] Revert SDPA attention changes that got in the
 static kv cache PR (#29027)

* revert unrelated changes that got in

* style
---
 .../models/mistral/modeling_mistral.py | 27 ++++++++-----------
 .../models/mixtral/modeling_mixtral.py | 27 ++++++++-----------
 .../models/qwen2/modeling_qwen2.py     | 27 ++++++++-----------
 3 files changed, 33 insertions(+), 48 deletions(-)

diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index cf8c0329b673d6..f4251b98304c4e 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -659,34 +659,28 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -695,9 +689,10 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 7a3870c333e5cf..674ace5f236039 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -736,34 +736,28 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -772,9 +766,10 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index fd6447e46b80d3..da0c9b8567752a 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -669,34 +669,28 @@ def forward(
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         kv_seq_len = key_states.shape[-2]
-        past_key_value = getattr(self, "past_key_value", past_key_value)
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
-        past_seen_tokens = kv_seq_len - key_states.shape[-2]
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        if (
-            attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
-        ):  # user defined causal mask
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
-            # this one liner is equivalent to the pad_unpad function
-            causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
-        else:
-            causal_mask = None
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
 
         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
-        if query_states.device.type == "cuda" and causal_mask is not None:
+        if query_states.device.type == "cuda" and attention_mask is not None:
             query_states = query_states.contiguous()
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()
@@ -705,9 +699,10 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attention_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
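
For quick reference, the standalone Python sketch below condenses the SDPA attention path that these hunks restore. It is not part of the patch: the function name `sdpa_attention` and its keyword arguments (e.g. `module_is_causal`, standing in for the module's `self.is_causal`) are illustrative. It shows the restored behavior: a user-supplied mask must already be a 4D `(bsz, 1, q_len, kv_seq_len)` tensor, inputs are made contiguous on CUDA to work around the torch==2.1.2 issue referenced in the hunks, and `is_causal` is enabled only when no mask is passed and `q_len > 1`.

import torch

def sdpa_attention(query_states, key_states, value_states, attention_mask=None,
                   dropout_p=0.0, training=False, module_is_causal=True):
    # query_states: (bsz, num_heads, q_len, head_dim); key/value states share the
    # same layout with kv_seq_len in place of q_len (after repeat_kv in the models).
    bsz, _, q_len, _ = query_states.shape
    kv_seq_len = key_states.shape[-2]

    if attention_mask is not None:
        # Restored shape check: the mask must already be expanded to 4D.
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        # Workaround kept by the patch: SDPA's memory-efficient backend is bugged
        # (torch==2.1.2) with non-contiguous inputs and a custom attn_mask on CUDA.
        if query_states.device.type == "cuda":
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

    return torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=attention_mask,
        dropout_p=dropout_p if training else 0.0,
        # Let SDPA build its own causal mask only when no explicit mask was given;
        # the q_len > 1 condition mirrors AttentionMaskConverter.to_causal_4d,
        # which creates no causal mask for a single query token.
        is_causal=module_is_causal and attention_mask is None and q_len > 1,
    )

# Example usage: causal self-attention over 5 tokens with 8 heads, no explicit mask.
q = torch.randn(1, 8, 5, 64)
k = torch.randn(1, 8, 5, 64)
v = torch.randn(1, 8, 5, 64)
out = sdpa_attention(q, k, v)  # shape: (1, 8, 5, 64)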