From 65f7edc86779eab4bb1fc6e529fbadc393fbbcb5 Mon Sep 17 00:00:00 2001
From: Bellk17
Date: Wed, 10 Apr 2024 17:13:32 -0700
Subject: [PATCH 1/2] Fix Triton compilation issue

Solves the following compilation error with the Triton flash attention
backend enabled:

.../vllm/attention/ops/triton_flash_attention.py: ...
triton.compiler.errors.UnsupportedLanguageConstruct: at 120:14:
    #  + offs_m
    # We store inf to LSE, not -inf because in the bwd pass,
    # we subtract this
    # from qk which makes it -inf, such that exp(qk - inf) = 0
    # for these masked blocks.
    # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
    # tl.store(l_ptrs, l)
    # TODO: Should dropout and return encoded softmax be handled here?
    return
    is_mqa = hq != hk
    off_h_k = off_h_q % hk if is_mqa else off_h_q
              ^
Triton does not support `if` expressions (ternary operators) with
dynamic conditions, use `if` statements instead
...
---
 vllm/attention/ops/triton_flash_attention.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 87cf30cbef79a..ca7d217d44eff 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -415,7 +415,11 @@ def attn_fwd(
         return

     is_mqa = hq != hk
-    off_h_k = off_h_q % hk if is_mqa else off_h_q
+    if is_mqa:
+        off_h_k = off_h_q % hk
+    else:
+        off_h_k = off_h_q
+
     n_extra_tokens = 0
     if seqlen_k < BLOCK_N:
         n_extra_tokens = BLOCK_N - seqlen_k

From fae4f82cb5d8027f47da7ffa0dba67c5516d85b6 Mon Sep 17 00:00:00 2001
From: Bellk17
Date: Thu, 11 Apr 2024 16:16:06 -0700
Subject: [PATCH 2/2] Update vllm/attention/ops/triton_flash_attention.py

Co-authored-by: Woosuk Kwon
---
 vllm/attention/ops/triton_flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index ca7d217d44eff..e160411859f0b 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -415,7 +415,7 @@ def attn_fwd(
         return

     is_mqa = hq != hk
-    if is_mqa:
+    if is_mqa:  # noqa: SIM108
         off_h_k = off_h_q % hk
     else:
         off_h_k = off_h_q
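
Note (not part of the patch): Triton JIT-compiles the kernel body, and a
ternary expression whose condition is a runtime scalar cannot be lowered,
while an `if` statement on the same condition can. Below is a minimal
standalone sketch of the same pattern, assuming a CUDA device and a working
Triton install; `copy_head_index`, its arguments, and the launch values are
hypothetical and are not part of the vLLM kernel.

    # Minimal sketch of the ternary-vs-statement distinction in Triton.
    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_head_index(out_ptr, hq, hk):
        pid = tl.program_id(0)
        # The ternary form is rejected at compile time because the
        # condition `hq != hk` is a runtime value:
        #   off_h_k = pid % hk if hq != hk else pid
        # The equivalent `if` statement compiles:
        if hq != hk:
            off_h_k = pid % hk
        else:
            off_h_k = pid
        tl.store(out_ptr + pid, off_h_k)

    out = torch.empty(8, dtype=torch.int32, device="cuda")
    copy_head_index[(8,)](out, 8, 2)  # MQA-style case: hq=8, hk=2

The `# noqa: SIM108` added in the second commit suppresses Ruff's SIM108
rule (flake8-simplify), which would otherwise suggest collapsing the
`if`/`else` block back into the very ternary that Triton rejects.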