Bump flash-attn to v2.0.4 (#816)
* Bump flash-attn to v2.0.4

* flash-attn v2.0.3 fixed the backward-pass issue when seqlen is not aligned on 128, so that workaround is removed

* prompt user to update submodules again if they forgot to use --recursive

* fix for latest flash-attn function signatures
tmm1 committed Aug 11, 2023
1 parent 1c29213 commit eadc8c6
Showing 3 changed files with 6 additions and 16 deletions.
4 changes: 2 additions & 2 deletions setup.py
@@ -143,13 +143,13 @@ def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
     nvcc_windows_flags = ["-Xcompiler", "/permissive-"]
 
     flash_root = os.path.join(this_dir, "third_party", "flash-attention")
-    if not os.path.exists(flash_root):
+    cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
+    if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
         raise RuntimeError(
             "flashattention submodule not found. Did you forget "
             "to run `git submodule update --init --recursive` ?"
         )
 
-    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
     sources = ["csrc/flash_attn/flash_api.cpp"]
     for f in glob.glob(os.path.join(flash_root, "csrc", "flash_attn", "src", "*.cu")):
         sources.append(str(Path(f).relative_to(flash_root)))
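In full, the reworked check in setup.py validates both the flash-attention checkout and its nested cutlass submodule before building. A minimal standalone sketch of the same pattern, with paths and error message taken from the diff (the `this_dir` definition pointing at the repository root is an assumption):

import os

# Assumption: this script lives at the repository root, as setup.py does.
this_dir = os.path.dirname(os.path.abspath(__file__))

flash_root = os.path.join(this_dir, "third_party", "flash-attention")
cutlass_inc = os.path.join(flash_root, "csrc", "cutlass", "include")
# The cutlass include directory only exists after a recursive submodule
# checkout, so checking it also catches a clone that skipped --recursive.
if not os.path.exists(flash_root) or not os.path.exists(cutlass_inc):
    raise RuntimeError(
        "flashattention submodule not found. Did you forget "
        "to run `git submodule update --init --recursive` ?"
    )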
16 changes: 3 additions & 13 deletions xformers/ops/fmha/flash.py
@@ -67,6 +67,7 @@ def _flash_fwd(
             out_padded,
             softmax_lse,
             p,
+            rng_state,
         ) = _C_flashattention.varlen_fwd(
             query,
             key,
@@ -83,7 +84,7 @@ def _flash_fwd(
             return_softmax,
             None,
         )
-        return out, softmax_lse, None
+        return out, softmax_lse, rng_state
 
     def _flash_bwd(
         grad,
@@ -123,6 +124,7 @@ def _flash_bwd(
             False,  # zero_tensors
             causal,
             None,
+            rng_state,
         )
         return dq
 
@@ -302,18 +304,6 @@ class BwOp(AttentionBwOpBase):

     MAX_HEADDIM_SM8x = 192
 
-    @classmethod
-    def shape_not_supported_reasons(
-        cls, Mq: int, Mkv: int, K: int, Kv: int
-    ) -> List[str]:
-        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
-        if (Mq % 128) or (Mkv % 128):
-            reasons.append(
-                "flashv2 beta: BW is incorrect when seqlen is not aligned on 128 "
-                "(https://github.com/Dao-AILab/flash-attention/issues/334)"
-            )
-        return reasons
-
     @classmethod
     def not_supported_reasons(cls, d: Inputs) -> List[str]:
         reasons = super(BwOp, cls).not_supported_reasons(d)
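The flash.py changes thread the RNG state that the forward kernel now returns through to the backward kernel, so the same dropout pattern can be replayed. A minimal runnable sketch of that data flow, with stub functions standing in for the compiled varlen_fwd/varlen_bwd calls (the stub names and abbreviated argument lists are illustrative, not the actual xformers or flash-attn signatures):

from typing import Any, Tuple

def varlen_fwd_stub(q: Any, k: Any, v: Any) -> Tuple[Any, Any, Any]:
    # Stand-in for _C_flashattention.varlen_fwd, which (per this diff) now
    # also returns the RNG state it used for dropout.
    out, softmax_lse, rng_state = "out", "lse", "rng_state"
    return out, softmax_lse, rng_state

def varlen_bwd_stub(grad: Any, out: Any, softmax_lse: Any, rng_state: Any) -> Any:
    # Stand-in for _C_flashattention.varlen_bwd, which now accepts rng_state
    # so the backward pass can reproduce the forward dropout mask.
    return "dq"

def flash_fwd(q: Any, k: Any, v: Any) -> Tuple[Any, Any, Any]:
    # Forward wrapper keeps rng_state instead of discarding it (returning None).
    out, softmax_lse, rng_state = varlen_fwd_stub(q, k, v)
    return out, softmax_lse, rng_state

def flash_bwd(grad: Any, out: Any, softmax_lse: Any, rng_state: Any) -> Any:
    # Backward wrapper forwards the saved rng_state to the kernel.
    return varlen_bwd_stub(grad, out, softmax_lse, rng_state)

out, lse, rng = flash_fwd("q", "k", "v")
dq = flash_bwd("dout", out, lse, rng)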
