fix for latest flash-attn function signatures
tmm1 committed Aug 11, 2023
1 parent fc48943 commit e654647
Showing 1 changed file with 3 additions and 1 deletion.
xformers/ops/fmha/flash.py: 3 additions & 1 deletion
@@ -67,6 +67,7 @@ def _flash_fwd(
         out_padded,
         softmax_lse,
         p,
+        rng_state,
     ) = _C_flashattention.varlen_fwd(
         query,
         key,
@@ -83,7 +84,7 @@ def _flash_fwd(
         return_softmax,
         None,
     )
-    return out, softmax_lse, None
+    return out, softmax_lse, rng_state

 def _flash_bwd(
     grad,
@@ -123,6 +124,7 @@ def _flash_bwd(
         False,  # zero_tensors
         causal,
         None,
+        rng_state,
     )
     return dq
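
Note: the hunks above only capture the new rng_state return value in the forward pass and thread it back into the backward pass. As a rough illustration of that calling pattern, here is a minimal, self-contained Python sketch. fake_varlen_fwd and fake_varlen_bwd are hypothetical stand-ins for _C_flashattention.varlen_fwd / varlen_bwd; the real bindings take many more positional arguments and their exact signatures depend on the installed flash-attn version, so only the rng_state handling mirrors the diff.

# Hedged sketch only: fake_varlen_fwd / fake_varlen_bwd are made-up stand-ins,
# not the real flash-attn bindings.
import torch


def fake_varlen_fwd(q, k, v, dropout_p):
    # Stand-in for _C_flashattention.varlen_fwd. Newer flash-attn builds return
    # the dropout RNG state as an extra value; the commit captures it instead
    # of discarding it.
    rng_state = torch.get_rng_state()  # placeholder for the kernel's RNG state
    out = q                            # placeholder "attention" output
    softmax_lse = q.new_zeros(q.shape[0])
    return out, softmax_lse, rng_state


def fake_varlen_bwd(grad, q, k, v, out, softmax_lse, rng_state):
    # Stand-in for _C_flashattention.varlen_bwd. Newer flash-attn builds accept
    # the saved RNG state as a trailing argument so the dropout mask can be
    # replayed identically in the backward pass.
    assert rng_state is not None, "rng_state must be threaded in from the forward pass"
    return torch.zeros_like(q)  # placeholder dq


# Forward: unpack and keep rng_state (first two hunks of the diff).
q = k = v = torch.randn(4, 8)
out, softmax_lse, rng_state = fake_varlen_fwd(q, k, v, dropout_p=0.1)

# Backward: pass rng_state back in as the new trailing argument (last hunk).
dq = fake_varlen_bwd(torch.ones_like(out), q, k, v, out, softmax_lse, rng_state)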

