diff --git a/setup.py b/setup.py
index e009cb7ff..8b26f62fd 100644
--- a/setup.py
+++ b/setup.py
@@ -143,13 +143,13 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     nvcc_windows_flags = ["-Xcompiler", "/permissive-"]
 
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
-    if not os.path.exists(flash_root):
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )
 
-    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     sources = ["csrc/flash_attn/flash_api.cpp"]
     for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
         sources.append(str(Path(f).relative_to(flash_root)))
diff --git a/third_party/flash-attention b/third_party/flash-attention
index 9ee0ff1d9..d30f2e1cd 160000
--- a/third_party/flash-attention
+++ b/third_party/flash-attention
@@ -1 +1 @@
-Subproject commit 9ee0ff1d9b6a99630e2a6868b9291dfa32d35abd
+Subproject commit d30f2e1cd50185c98ed88c0684b4a603f15bee37
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 87e13d24a..ee3d8257c 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -67,6 +67,7 @@ def _flash_fwd(
             out_padded,
             softmax_lse,
             p,
+            rng_state,
         ) = _C_flashattention.varlen_fwd(
             query,
             key,
@@ -83,7 +84,7 @@
             return_softmax,
             None,
         )
-        return out, softmax_lse, None
+        return out, softmax_lse, rng_state
 
     def _flash_bwd(
         grad,
@@ -123,6 +124,7 @@
             False,  # zero_tensors
             causal,
             None,
+            rng_state,
         )
         return dq
 
@@ -302,18 +304,6 @@ class BwOp(AttentionBwOpBase):
 
     MAX_HEADDIM_SM8x = 192
 
-    @classmethod
-    def shape_not_supported_reasons(
-        cls, Mq: int, Mkv: int, K: int, Kv: int
-    ) -> List[str]:
-        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
-        if (Mq % 128) or (Mkv % 128):
-            reasons.append(
-                "flashv2 beta: BW is incorrect when seqlen is not aligned on 128 "
-                "(https://github.com/Dao-AILab/flash-attention/issues/334)"
-            )
-        return reasons
-
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)