
Commit

Flash-attn performance: remove cuda sync during inference (huggingface#33570)

Switch conditions to use short-circuit during inference
Cyrilvallez authored and BernardZach committed Dec 5, 2024
1 parent 08e87c0 commit 803e2e1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/modeling_flash_attention_utils.py
@@ -267,7 +267,8 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
+    # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
+    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
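Why the reordering removes the sync: Python's `and` evaluates left to right and stops at the first falsy operand. Converting the result of `(torch.diff(position_ids, dim=-1) >= 0).all()` to a Python bool forces the host to wait for the GPU, so with the old ordering every decoding step paid that synchronization. With `query_length != 1` placed first, the check is a plain Python comparison that is False during token-by-token inference (where `query_length == 1`), so the tensor reduction never runs. Below is a minimal sketch of the two orderings; the shapes and device handling are illustrative assumptions, not taken from the commit.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Illustrative decode-time inputs (assumed values, not from the commit).
position_ids = torch.arange(128, device=device).unsqueeze(0)  # increasing -> a single sequence
query_length = 1  # one new token per step during autoregressive decoding

# Old ordering: the tensor reduction runs first, and applying `not` to its
# 0-dim tensor result reads it back to the host, i.e. a device synchronization
# on every decoding step.
if position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
    pass

# New ordering: `query_length != 1` is a cheap Python comparison; it is False
# during decoding, so `and` short-circuits and the `torch.diff(...).all()`
# kernel (and its readback) never runs.
if position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
    pass

The same ordering pattern applies whenever a branch mixes cheap Python-level checks with conditions that read back values from GPU tensors: evaluate the tensor read last.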
