From 803e2e1c11d3d38dc259efc08297a297e0bb7aa5 Mon Sep 17 00:00:00 2001
From: Cyril Vallez
Date: Mon, 7 Oct 2024 09:52:19 +0200
Subject: [PATCH] Flash-attn performance: remove cuda sync during inference
 (#33570)

Switch conditions to use short-circuit during inference
---
 src/transformers/modeling_flash_attention_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 44e61825dd9cd6..da961c6060e499 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -267,7 +267,8 @@ def _flash_attention_forward(
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
+    # Note: the `torch.diff(...)` condition is last to use short-circuit and avoid the cuda synchronization it incurs during inference (query_length == 1 always)
+    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
         batch_size = query_states.size(0)
         query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
             query_states, key_states, value_states, position_ids
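
For context (not part of the patch itself), here is a minimal sketch of the short-circuit behaviour the commit relies on: coercing the result of `(torch.diff(...) >= 0).all()` to a Python bool forces a device-to-host copy, i.e. a cuda synchronization, so evaluating the cheap `query_length != 1` comparison first skips that reduction entirely during single-token decoding. The tensor shapes and values below are illustrative assumptions, not taken from the library.

```python
# Sketch only: demonstrates how `and` short-circuiting avoids the cuda sync.
# Assumes the shapes/values below; they are illustrative, not from transformers.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
position_ids = torch.arange(128, device=device).unsqueeze(0)  # monotonically increasing ids
query_length = 1  # single-token decode step during inference

# Old order: the tensor reduction runs first. Wrapping it in `not ...` calls
# __bool__ on a device tensor, which blocks until the GPU result is copied back,
# on every decode step.
old_condition = (
    position_ids is not None
    and not (torch.diff(position_ids, dim=-1) >= 0).all()
    and query_length != 1
)

# New order: `query_length != 1` is a plain Python comparison. During decoding it
# is False, so short-circuit evaluation never reaches the tensor reduction and no
# synchronization occurs.
new_condition = (
    position_ids is not None
    and query_length != 1
    and not (torch.diff(position_ids, dim=-1) >= 0).all()
)

print(old_condition, new_condition)  # both False here, but only the old order synced
```

During prefill or training (`query_length != 1` is True) both orderings still evaluate the `torch.diff` check, so the packed-sequence detection is unchanged; only the decode path avoids the sync.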