Bugfixes/improvements in `memory_efficient_attention`

@danthe3rd released this 06 Dec 16:05

Pre-built binary wheels require PyTorch 2.1.1

Fixed

  • fMHA: Fixed a bug in the forward pass of the cutlass backend where the logsumexp was computed incorrectly, which produced wrong results in the backward pass. This happened with MQA when one sequence had a query with length % 64 == 1
  • fMHA: Updated Flash-Attention to v2.3.6. This fixes a performance regression in causal backward passes and adds support for BlockDiagonalCausalWithOffsetPaddedKeysMask
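
As background on why a wrong logsumexp corrupts the backward pass: memory-efficient attention kernels typically do not store the attention probabilities, and instead rebuild them in the backward pass from the per-row logsumexp saved by the forward pass. The pure-Python sketch below illustrates that idea for a single query row. It is not the cutlass kernel, and the helper names `logsumexp` and `probs_from_lse` are hypothetical.

```python
import math


def logsumexp(scores):
    """Numerically stable log(sum(exp(s))) over one row of attention scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def probs_from_lse(scores, lse):
    """Rebuild softmax probabilities from the scores and a saved logsumexp."""
    return [math.exp(s - lse) for s in scores]


# One query row attending to three keys (illustrative values only).
scores = [0.5, -1.0, 2.0]
lse = logsumexp(scores)

# With the correct logsumexp the rebuilt probabilities sum to 1.
good = probs_from_lse(scores, lse)

# With a slightly wrong logsumexp every probability is scaled by the same
# factor, so anything computed from them in the backward pass is off too.
bad = probs_from_lse(scores, lse + 0.1)
```

In this sketch an error of 0.1 in the saved value scales every rebuilt probability by exp(-0.1), so the forward output can look correct while the gradients do not.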

Added

  • fMHA: Added LocalAttentionFromBottomRightMask (local)
  • fMHA: Added LowerTriangularFromBottomRightMask (causal)
  • fMHA: Added LowerTriangularFromBottomRightLocalAttentionMask (local + causal)
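
The three masks above are aligned to the bottom-right corner of the attention matrix, which matters when the number of queries differs from the number of keys: the last query lines up with the last key. The pure-Python sketch below shows the resulting patterns. It does not call xformers, the helper names are hypothetical, and the window convention for the local + causal case (a window of `window_size` keys ending at the aligned key) is an assumption to check against the xformers documentation.

```python
def bottom_right_mask(num_queries, num_keys, window_left=None, window_right=None):
    """Return a boolean matrix where True means query i may attend key j.

    Hypothetical helper for illustration; not part of xformers.
    A bound of None means "unlimited" on that side.
    """
    shift = num_keys - num_queries  # lines up the last query with the last key
    mask = []
    for i in range(num_queries):
        diag = i + shift  # key position aligned with query i
        row = []
        for j in range(num_keys):
            allowed = True
            if window_left is not None and j < diag - window_left:
                allowed = False
            if window_right is not None and j > diag + window_right:
                allowed = False
            row.append(allowed)
        mask.append(row)
    return mask


def causal_from_bottom_right(num_queries, num_keys):
    """Causal: no key to the right of the aligned position."""
    return bottom_right_mask(num_queries, num_keys, window_right=0)


def local_causal_from_bottom_right(num_queries, num_keys, window_size):
    """Local + causal: the last `window_size` keys up to the aligned position."""
    return bottom_right_mask(
        num_queries, num_keys, window_left=window_size - 1, window_right=0
    )


def show(mask):
    """Render a mask as text: 'x' = may attend, '.' = masked out."""
    return "\n".join("".join("x" if ok else "." for ok in row) for row in mask)


# 2 queries, 4 keys: the final query sees every key.
print(show(causal_from_bottom_right(2, 4)))
# xxx.
# xxxx

# Same shape with a window of 2 keys.
print(show(local_causal_from_bottom_right(2, 4, window_size=2)))
# .xx.
# ..xx
```

A top-left aligned causal mask with the same shapes would instead let the first query see only the first key, which is why the bottom-right variants suit decoding with a cached prefix of keys.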

Removed

  • Removed xformers.triton.sum_strided