From 8d1a8e5eecc17171acc93e656e360cd7446d904c Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Thu, 3 Aug 2023 17:55:36 +0000 Subject: [PATCH 1/4] Bump flash-attn to v2.0.4 --- third_party/flash-attention | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/flash-attention b/third_party/flash-attention index 9ee0ff1d9..d30f2e1cd 160000 --- a/third_party/flash-attention +++ b/third_party/flash-attention @@ -1 +1 @@ -Subproject commit 9ee0ff1d9b6a99630e2a6868b9291dfa32d35abd +Subproject commit d30f2e1cd50185c98ed88c0684b4a603f15bee37 From 698532d184c0d404e0fc8d25b901cc082d52c01d Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Thu, 3 Aug 2023 18:37:29 +0000 Subject: [PATCH 2/4] flash-attn v2.0.3 fixed this issue --- xformers/ops/fmha/flash.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 87e13d24a..3e54b6705 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -302,18 +302,6 @@ class BwOp(AttentionBwOpBase): MAX_HEADDIM_SM8x = 192 - @classmethod - def shape_not_supported_reasons( - cls, Mq: int, Mkv: int, K: int, Kv: int - ) -> List[str]: - reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) - if (Mq % 128) or (Mkv % 128): - reasons.append( - "flashv2 beta: BW is incorrect when seqlen is not aligned on 128 " - "(https://github.com/Dao-AILab/flash-attention/issues/334)" - ) - return reasons - @classmethod def not_supported_reasons(cls, d: Inputs) -> List[str]: reasons = super(BwOp, cls).not_supported_reasons(d) From fc48943e0e890577fda325ae16be3b077881a948 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Thu, 3 Aug 2023 19:12:50 +0000 Subject: [PATCH 3/4] prompt user to update submodules again if they forgot to use --recursive --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index e009cb7ff..8b26f62fd 100644 --- a/setup.py +++ b/setup.py @@ -143,13 +143,13 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args): nvcc_windows_flags = ["-Xcompiler", "/permissive-"] flash_root = os.path.join(this_dir, "third_party", "flash-attention") - if not os.path.exists(flash_root): + cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include") + if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc): raise RuntimeError( "flashattention submodule not found. Did you forget " "to run `git submodule update --init --recursive` ?" ) - flash_root = os.path.join(this_dir, "third_party", "flash-attention") sources = ["csrc/flash_attn/flash_api.cpp"] for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")): sources.append(str(Path(f).relative_to(flash_root))) From e654647a81b9f54c2cfe2c18215a5ef73e507526 Mon Sep 17 00:00:00 2001 From: Aman Karmani Date: Fri, 4 Aug 2023 18:10:26 +0000 Subject: [PATCH 4/4] fix for latest flash-attn function signatures --- xformers/ops/fmha/flash.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 3e54b6705..ee3d8257c 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -67,6 +67,7 @@ def _flash_fwd( out_padded, softmax_lse, p, + rng_state, ) = _C_flashattention.varlen_fwd( query, key, @@ -83,7 +84,7 @@ def _flash_fwd( return_softmax, None, ) - return out, softmax_lse, None + return out, softmax_lse, rng_state def _flash_bwd( grad, @@ -123,6 +124,7 @@ def _flash_bwd( False, # zero_tensors causal, None, + rng_state, ) return dq