From eadc8c63dc3db42ebd95b8d4f44ca6244039cfd7 Mon Sep 17 00:00:00 2001
From: Aman Gupta Karmani
Date: Fri, 11 Aug 2023 02:00:55 -0700
Subject: [PATCH] Bump flash-attn to v2.0.4 (#816)

* Bump flash-attn to v2.0.4

* flash-attn v2.0.3 fixed this issue

* prompt user to update submodules again if they forgot to use --recursive

* fix for latest flash-attn function signatures
---
 setup.py                    |  4 ++--
 third_party/flash-attention |  2 +-
 xformers/ops/fmha/flash.py  | 16 +++-------------
 3 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/setup.py b/setup.py
index e009cb7ff0..8b26f62fdd 100644
--- a/setup.py
+++ b/setup.py
@@ -143,13 +143,13 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
         nvcc_windows_flags = ["-Xcompiler", "/permissive-"]

     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
-    if not os.path.exists(flash_root):
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )

-    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     sources = ["csrc/flash_attn/flash_api.cpp"]
     for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
         sources.append(str(Path(f).relative_to(flash_root)))
diff --git a/third_party/flash-attention b/third_party/flash-attention
index 9ee0ff1d9b..d30f2e1cd5 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit 9ee0ff1d9b6a99630e2a6868b9291dfa32d35abd
+Subproject commit d30f2e1cd50185c98ed88c0684b4a603f15bee37
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 87e13d24a5..ee3d8257cf 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -67,6 +67,7 @@ def _flash_fwd(
             out_padded,
             softmax_lse,
             p,
+            rng_state,
         ) = _C_flashattention.varlen_fwd(
             query,
             key,
@@ -83,7 +84,7 @@ def _flash_fwd(
             return_softmax,
             None,
         )
-        return out, softmax_lse, None
+        return out, softmax_lse, rng_state

     def _flash_bwd(
         grad,
@@ -123,6 +124,7 @@ def _flash_bwd(
             False, # zero_tensors
             causal,
             None,
+            rng_state,
         )
         return dq

@@ -302,18 +304,6 @@ class BwOp(AttentionBwOpBase):

     MAX_HEADDIM_SM8x = 192

-    @classmethod
-    def shape_not_supported_reasons(
-        cls, Mq: int, Mkv: int, K: int, Kv: int
-    ) -> List[str]:
-        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
-        if (Mq % 128) or (Mkv % 128):
-            reasons.append(
-                "flashv2 beta: BW is incorrect when seqlen is not aligned on 128 "
-                "(https://github.com/Dao-AILab/flash-attention/issues/334)"
-            )
-        return reasons
-
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)
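
Note on the setup.py hunk: CUTLASS is a submodule nested inside the flash-attention submodule, so a plain `git submodule update --init` populates `third_party/flash-attention` but leaves `csrc/cutlass` empty, and the build only fails later with a missing-header error. Probing for the nested `include` directory surfaces the actionable message up front. Below is a minimal, hypothetical helper in the same spirit; the name `ensure_submodule` and its `marker` argument are illustrative and not part of the patch.

```python
import os


def ensure_submodule(root: str, marker: str) -> None:
    # 'marker' is a path that only exists once the (possibly nested)
    # submodule has actually been checked out, e.g. "csrc/cutlass/include".
    if not os.path.exists(root) or not os.path.exists(os.path.join(root, marker)):
        raise RuntimeError(
            f"submodule at {root!r} looks incomplete (missing {marker!r}). "
            "Did you forget to run `git submodule update --init --recursive`?"
        )


# e.g. ensure_submodule("third_party/flash-attention", "csrc/cutlass/include")
```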
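
Note on the flash.py hunks: with flash-attn v2.0.x, `varlen_fwd` also returns the RNG state it used for dropout, and `varlen_bwd` takes that state back so the backward kernel can regenerate the same dropout pattern deterministically instead of storing the mask. The sketch below is only a generic illustration of that save-and-replay pattern using plain PyTorch dropout; it is not the flash-attn API.

```python
import torch


class DropoutWithRngReplay(torch.autograd.Function):
    """Save the RNG state used in forward() and replay it in backward(),
    so the dropout mask never has to be materialized between the two passes."""

    @staticmethod
    def forward(ctx, x, p):
        ctx.rng_state = torch.get_rng_state()  # capture *before* drawing the mask
        mask = (torch.rand_like(x) >= p).to(x.dtype) / (1.0 - p)
        ctx.p = p
        return x * mask

    @staticmethod
    def backward(ctx, grad_out):
        saved = torch.get_rng_state()          # don't clobber the global RNG stream
        torch.set_rng_state(ctx.rng_state)
        mask = (torch.rand_like(grad_out) >= ctx.p).to(grad_out.dtype) / (1.0 - ctx.p)
        torch.set_rng_state(saved)
        return grad_out * mask, None           # d(x * mask)/dx == mask


# Gradients use the same mask as the forward pass even though it was never stored:
x = torch.randn(4, 8, requires_grad=True)
DropoutWithRngReplay.apply(x, 0.1).sum().backward()
```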