
[Llama FA2] Re-add _expand_attention_mask and clean a couple things #27074

Merged
merged 12 commits on Oct 26, 2023
17 changes: 10 additions & 7 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -148,7 +148,7 @@ def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -157,12 +157,16 @@ def to_4d(
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)

past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
@@ -182,8 +186,8 @@ def _make_causal_mask(

return expanded_4d_mask

@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -212,7 +216,8 @@ def _make_causal_mask(

return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -837,7 +842,7 @@ def _flash_attention_forward(
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)

return attn_output
@@ -1154,8 +1159,6 @@ def __init__(self, config: FalconConfig):
# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)

# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)

# Transformer blocks
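For context on the `to_4d` change above, here is a minimal usage sketch (assuming `AttnMaskConverter` remains importable from `modeling_falcon` as defined in this diff, and that the values below are purely illustrative): with `key_value_length` now optional, a causal converter raises a clear `ValueError` when it is omitted instead of failing on `None - query_length`.

import torch
from transformers.models.falcon.modeling_falcon import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True)

# 2D padding mask: 1 = attend, 0 = padding (values are illustrative)
attention_mask_2d = torch.tensor([[1, 1, 1, 0]])

# key_value_length = query_length + past_key_values_length; for a causal
# converter it must now be passed explicitly, otherwise a ValueError is raised.
mask_4d = converter.to_4d(
    attention_mask_2d,
    query_length=4,
    key_value_length=4,
    dtype=torch.float32,
)
# mask_4d has shape (batch_size, 1, query_length, key_value_length) == (1, 1, 4, 4)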
42 changes: 32 additions & 10 deletions src/transformers/models/llama/modeling_llama.py
@@ -64,6 +64,24 @@ def _get_unpad_data(attention_mask):
)


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
warnings.warn(
"Calling `transformers.models.llama.modeling_llama._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._expand_mask"
)
return AttnMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
Comment on lines +67 to +71

Collaborator: Nice! We should probably do the same for falcon and mistral as well

Collaborator: ok 👍🏻



def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
warnings.warn(
"Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.llama.modeling_llama.AttnMaskConverter._make_causal_mask"
)
return AttnMaskConverter._make_causal_mask(
input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
)
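
As a hedged illustration of the two deprecation shims above (assuming both names stay importable from `modeling_llama` while the shims exist), existing callers keep working with a warning and can migrate to the static methods on the converter class:

import torch
from transformers.models.llama.modeling_llama import AttnMaskConverter, _expand_mask

mask_2d = torch.tensor([[1, 1, 0]])

# Old path: still works, but now emits the deprecation warning added in this PR.
old_mask = _expand_mask(mask_2d, dtype=torch.float32, tgt_len=3)

# New path: call the static method on the converter class directly.
new_mask = AttnMaskConverter._expand_mask(mask=mask_2d, dtype=torch.float32, tgt_len=3)

assert torch.equal(old_mask, new_mask)  # the shim simply forwards to the static method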


class AttnMaskConverter:
"""
A utility attention mask class that allows:
@@ -122,7 +140,7 @@ def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -131,12 +149,16 @@ def to_4d(
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)

past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
@@ -156,8 +178,8 @@ def _make_causal_mask(

return expanded_4d_mask

@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -186,7 +208,8 @@ def _make_causal_mask(

return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -555,7 +578,7 @@ def forward(
value_states = self.v_proj(hidden_states)

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dime x hidden_dim
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -669,7 +692,7 @@ def _flash_attention_forward(
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)

return attn_output
@@ -739,8 +762,9 @@ def forward(
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -914,8 +938,6 @@ def __init__(self, config: LlamaConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True)

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
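The decoder-layer docstring change a few hunks above distinguishes two mask layouts. A small sketch of what each looks like (shapes only; the names below are local to this example, not taken from the library):

import torch

batch_size, q_len, kv_len = 2, 5, 5

# Flash-attention path: a 2D padding mask of shape (batch_size, sequence_length),
# 1 for real tokens and 0 for padding.
flash_attention_mask = torch.ones(batch_size, kv_len, dtype=torch.long)
flash_attention_mask[1, -2:] = 0  # second sequence ends with two padding tokens

# Default (eager) path: an additive 4D float mask of shape
# (batch_size, 1, query_sequence_length, key_sequence_length),
# 0.0 where attention is allowed and dtype-min where it is masked out.
default_attention_mask = torch.zeros(batch_size, 1, q_len, kv_len)
default_attention_mask[1, :, :, -2:] = torch.finfo(default_attention_mask.dtype).min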
24 changes: 16 additions & 8 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -113,7 +113,7 @@ def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
@@ -122,12 +122,16 @@ def to_4d(
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
past_key_values_length = key_value_length - query_length

# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)

past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
@@ -147,8 +151,8 @@ def _make_causal_mask(

return expanded_4d_mask

@staticmethod
def _make_causal_mask(
self,
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
@@ -177,7 +181,8 @@ def _make_causal_mask(

return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
@@ -645,7 +650,12 @@ def _flash_attention_forward(
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)
else:
attn_output = flash_attn_func(
@@ -654,7 +664,7 @@
value_states,
dropout,
softmax_scale=softmax_scale,
causal=True,
causal=self.is_causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)

@@ -903,8 +913,6 @@ def __init__(self, config: MistralConfig):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

# create attention mask cache that trickles down to each attention layer
# so that the attention_mask cache can be shared among layers
self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window)

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
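Finally, a rough sketch of the Mistral sliding-window flash-attention call after this change, where `causal` follows the module's `is_causal` flag instead of being hard-coded to `True`. This is not the exact code in the file: it needs a CUDA device and flash-attn >= 2.3 (for `window_size`), and the shapes and values below are illustrative only.

import torch
from flash_attn import flash_attn_func

bsz, seq_len, num_heads, head_dim = 1, 16, 4, 64
sliding_window = 8
is_causal = True  # mirrors self.is_causal on the attention module

# flash_attn_func expects inputs of shape (batch, seq_len, num_heads, head_dim).
query_states = torch.randn(bsz, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key_states = torch.randn_like(query_states)
value_states = torch.randn_like(query_states)

attn_output = flash_attn_func(
    query_states,
    key_states,
    value_states,
    dropout_p=0.0,
    softmax_scale=None,
    causal=is_causal,
    window_size=(sliding_window, sliding_window),
)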