Refactor flash attention implementation in transformers #31446

Merged (62 commits, Jul 11, 2024)
Commits
b66fdb0 dumb commit (ArthurZucker, May 31, 2024)
029ee11 nit (ArthurZucker, Jun 17, 2024)
9a7885d Merge branch 'main' into backend-compatible (ArthurZucker, Jun 17, 2024)
ac3e5b5 update (ArthurZucker, Jun 17, 2024)
a7c48bd something like this (ArthurZucker, Jun 17, 2024)
682f221 unpack in modeling utils (ArthurZucker, Jun 17, 2024)
2201178 safe import (ArthurZucker, Jun 17, 2024)
55a3503 oups (ArthurZucker, Jun 21, 2024)
b5cbaef update (ArthurZucker, Jun 26, 2024)
7c6fdd7 nits (ArthurZucker, Jun 26, 2024)
08d7e1e diff convert gemma (ArthurZucker, Jun 26, 2024)
27044da update (ArthurZucker, Jun 26, 2024)
ca316a0 start propagating (ArthurZucker, Jun 26, 2024)
ea93267 udpate other modeling code as well (ArthurZucker, Jun 26, 2024)
4b67223 update for sliding window models (ArthurZucker, Jun 26, 2024)
d59ac0c nits (ArthurZucker, Jun 26, 2024)
a1d3866 more init cleanups (ArthurZucker, Jun 26, 2024)
aea7f03 styling (ArthurZucker, Jun 26, 2024)
f1bedd0 fixup (ArthurZucker, Jun 26, 2024)
86e2edc noice (ArthurZucker, Jun 26, 2024)
e90a944 pass fixup (ArthurZucker, Jun 26, 2024)
093fbf5 typo typing_extension -> typing_extensions (ArthurZucker, Jun 26, 2024)
a1c56d2 torch.nn.functionnal -> torch.nn.functional (ArthurZucker, Jun 26, 2024)
1aad4a2 add to import structure (ArthurZucker, Jun 26, 2024)
10bc1fa unpack (ArthurZucker, Jun 26, 2024)
9f08ddb simplify a bit more for this first version (ArthurZucker, Jun 26, 2024)
2e65e57 nut (ArthurZucker, Jun 26, 2024)
f8622e6 update (ArthurZucker, Jun 26, 2024)
2bb4347 update (ArthurZucker, Jun 26, 2024)
9be7579 nit (ArthurZucker, Jun 26, 2024)
889cbf8 ease the import of `Unpack` (ArthurZucker, Jun 26, 2024)
070af2d remove useless `use_sliding_window` (ArthurZucker, Jun 26, 2024)
80057a0 no qua please (ArthurZucker, Jun 26, 2024)
c0b024d protect import? (ArthurZucker, Jun 26, 2024)
8f7d1c1 style (ArthurZucker, Jun 26, 2024)
46b77f9 [run-slow] (ArthurZucker, Jun 26, 2024)
4a98ee7 [run slow] llama,gemma,mistral,mixtral (ArthurZucker, Jun 26, 2024)
25b2c10 remove extra kwargs (ArthurZucker, Jun 26, 2024)
8c3780d Merge branch 'main' of github.com:huggingface/transformers into backe… (ArthurZucker, Jun 26, 2024)
1d38dab fix llama (ArthurZucker, Jun 26, 2024)
f64864a address review comments (fxmarty, Jul 1, 2024)
565c5dc apply diff_model_converter to modeling_gemma.py (fxmarty, Jul 1, 2024)
2403ce5 Merge branch 'main' into backend-compatible (fxmarty, Jul 1, 2024)
c89571d remove cache_position 1 (fxmarty, Jul 2, 2024)
32c2df8 remove cache_position 2 (fxmarty, Jul 2, 2024)
54a9fb0 some cleaning (fxmarty, Jul 2, 2024)
206731e refactor gemma2 as well (fxmarty, Jul 2, 2024)
7c65fc7 Merge branch 'main' into backend-compatible (fxmarty, Jul 2, 2024)
1be8c31 apply review comments (fxmarty, Jul 3, 2024)
8d181ea rename file to modeling_flash_attention_utils.py (fxmarty, Jul 3, 2024)
3a30cb6 Merge branch 'main' into backend-compatible (fxmarty, Jul 8, 2024)
c92028a siglip refactor (fxmarty, Jul 8, 2024)
7243993 remove dead code (fxmarty, Jul 8, 2024)
8b077d8 is the hub down? (fxmarty, Jul 8, 2024)
a9796bc still down? (fxmarty, Jul 9, 2024)
6752a9c fix siglip (fxmarty, Jul 10, 2024)
3a9cf1b Merge branch 'main' into backend-compatible (fxmarty, Jul 11, 2024)
b4d1df5 fix gemma2 (fxmarty, Jul 11, 2024)
1e1bc2f fatal: Could not read from remote repository. (fxmarty, Jul 11, 2024)
c79ca83 fix typo in softcap implem (fxmarty, Jul 11, 2024)
30dc123 flacky (fxmarty, Jul 11, 2024)
fae6843 Failed: Timeout >120.0s (fxmarty, Jul 11, 2024)

1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1285,6 +1285,7 @@
"WhisperTimeStampLogitsProcessor",
]
)
_import_structure["modeling_flash_attention_utils"]: []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_utils"] = ["PreTrainedModel"]

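For context, a brief sketch of what this registration implies for downstream imports (an assumption based on transformers' lazy-module convention; `_flash_attention_forward` is a private helper and its surface may change):

# Because the entry above maps the submodule to an empty list, nothing is re-exported from
# the top-level `transformers` package; the helper is imported from its submodule instead.
from transformers.modeling_flash_attention_utils import _flash_attention_forward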
211 changes: 211 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
@@ -0,0 +1,211 @@
# coding=utf-8
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available


if is_flash_attn_2_available():
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn import flash_attn_func, flash_attn_varlen_func

_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Retrieves indexing data required to repad unpadded (ragged) tensors.

Arguments:
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

Return:
indices (`torch.Tensor`):
The indices of non-masked tokens from the flattened input sequence.
cu_seqlens (`torch.Tensor`):
The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
max_seqlen_in_batch (`int`):
Maximum sequence length in batch.
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


def _upad_input(
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
):
"""
Unpads the query, key, and value tensors, using a single dimension for all tokens even though they belong to different batches.

This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
tensors for query, key, value tensors.

Arguments:
query_layer (`torch.Tensor`):
Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
attention_mask (`torch.Tensor`):
Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
query_length (`int`):
Target length.

Return:
query_layer (`torch.Tensor`):
Query state without padding. Shape: (total_target_length, num_heads, head_dim).
key_layer (`torch.Tensor`):
Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
value_layer (`torch.Tensor`):
Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
indices_q (`torch.Tensor`):
The indices of non-masked tokens from the flattened input target sequence.
(cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
(max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
"""
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
attention_mask: torch.Tensor,
query_length: int,
is_causal: bool,
dropout: float = 0.0,
softmax_scale: Optional[float] = None,
sliding_window: Optional[int] = None,
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
):
"""
Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token, the inputs are first unpadded, the attention scores are computed, and the final attention output is then re-padded to the original shape.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
query_length (`int`):
Length of the query (target) sequence.
is_causal (`bool`):
Whether to apply a causal attention mask.
dropout (`float`):
Attention dropout probability.
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
sliding_window (`int`, *optional*):
If set, and if the installed flash-attn build supports `window_size`, restricts attention to a sliding window of this size.
use_top_left_mask (`bool`, defaults to `False`):
flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right alignment, which became the default in flash_attn>=2.1. This argument handles the difference.
softcap (`float`, *optional*):
Softcap for the attention logits, used e.g. in gemma2.
"""
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
causal = is_causal and query_length != 1

# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if softcap is not None:
flash_kwargs["softcap"] = softcap

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)

return attn_output
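
To make the new shared helper concrete, here is a minimal usage sketch of `_flash_attention_forward` as defined above. The shapes, dtype, and device are illustrative assumptions: the flash-attn 2.x kernels require a CUDA/ROCm GPU and fp16 or bf16 inputs, so treat this as a sketch rather than a drop-in test.

import torch

from transformers.modeling_flash_attention_utils import _flash_attention_forward

batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch_size, seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
key = torch.randn_like(query)
value = torch.randn_like(query)

# Padding mask: 1 marks real tokens, 0 marks padding (the second sequence has 4 padded positions).
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int32, device="cuda")
attention_mask[1, :4] = 0

attn_output = _flash_attention_forward(
    query,
    key,
    value,
    attention_mask,
    query_length=seq_len,
    is_causal=True,
    dropout=0.0,
    sliding_window=None,      # e.g. 4096 for sliding-window models such as Mistral
    use_top_left_mask=False,  # True only for the flash-attn < 2.1 top-left causal mask convention
    softcap=None,             # e.g. a float for Gemma 2-style logit softcapping
)
# attn_output is re-padded to shape (batch_size, seq_len, num_heads, head_dim).

This is the same call pattern the refactored model files adopt below, where each attention class passes its own `is_causal` and `_flash_attn_uses_top_left_mask` flags instead of carrying a private copy of the function.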
126 changes: 11 additions & 115 deletions src/transformers/models/bark/modeling_bark.py
@@ -54,8 +54,7 @@


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)
@@ -65,19 +64,6 @@
_CONFIG_FOR_DOC = "BarkConfig"


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


class BarkSelfAttention(nn.Module):
# adapted from GPTNeoSelfAttention and Bark code
# BarkSelfAttention can have two attention type, i.e full attention or causal attention
@@ -270,7 +256,16 @@ def forward(
else:
present = None

attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
attn_output = _flash_attention_forward(
query,
key,
value,
attention_mask,
query_len,
dropout=self.dropout,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
@@ -283,105 +278,6 @@ def forward(

return outputs

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1

# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


BARK_ATTENTION_CLASSES = {
"eager": BarkSelfAttention,
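
Finally, a small sketch of the unpadding bookkeeping that per-model helpers such as Bark's `_get_unpad_data` and `_upad_input` (removed above) now delegate to the shared module. It mirrors the body of `_get_unpad_data` and runs on CPU; the values in the comments are worked out by hand for this specific mask.

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)              # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # tensor([0, 1, 2, 4, 5])
max_seqlen_in_batch = seqlens_in_batch.max().item()                           # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 3, 5])

# `indices` selects the 5 real tokens out of the flattened (batch_size * seq_len,) axis,
# and `cu_seqlens` gives the boundaries of each sequence inside that packed layout,
# which is exactly what flash_attn_varlen_func consumes.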