
Bump flash-attn to v2.0.4 #816

Merged: 4 commits merged into facebookresearch:main on Aug 11, 2023
Conversation

@tmm1 (Contributor) commented Aug 3, 2023

What does this PR do?

Fixes #805

see https://github.com/Dao-AILab/flash-attention/commits/main for recent fixes

cc #712 Dao-AILab/flash-attention#359 Dao-AILab/flash-attention#334
cc #795 @danthe3rd
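
To confirm which FlashAttention build an installed xformers actually picked up, one quick check (assuming a reasonably recent xformers) is:

$ python -m xformers.info

which lists the available memory_efficient_attention operators, including the flash forward/backward ops.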

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label on Aug 3, 2023 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@codecov-commenter commented Aug 4, 2023

Codecov Report

Patch coverage: 96.15% and project coverage change: +0.12% 🎉

Comparison is base (f525106) 81.73% compared to head (e115d8e) 81.85%.
Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #816      +/-   ##
==========================================
+ Coverage   81.73%   81.85%   +0.12%     
==========================================
  Files          96       96              
  Lines        6401     6427      +26     
==========================================
+ Hits         5232     5261      +29     
+ Misses       1169     1166       -3     
Flag      Coverage Δ
Python    81.85% <96.15%> (+0.12%) ⬆️

Flags with carried forward coverage won't be shown.

Files Changed                   Coverage Δ
xformers/ops/fmha/flash.py      48.46% <80.00%> (-1.17%) ⬇️
xformers/ops/fmha/triton.py     70.37% <100.00%> (+14.67%) ⬆️


@ghunkins commented Aug 8, 2023

Would love to see this added!

@danthe3rd (Contributor) left a comment

Hi,
Thanks for opening this PR!
Happy to merge once you have reverted the changes to nvcc_flags in setup.py.

setup.py (Outdated)
Comment on lines 255 to 270

 ]
-extra_compile_args["nvcc"] = nvcc_flags
-
-ext_modules += get_flash_attention_extensions(
-    cuda_version=cuda_version, extra_compile_args=extra_compile_args
-)
-
 # NOTE: This should not be applied to Flash-Attention
 # see https://github.com/Dao-AILab/flash-attention/issues/359
-extra_compile_args["nvcc"] += [
+nvcc_flags += [
     # Workaround for a regression with nvcc > 11.6
     # See https://github.com/facebookresearch/xformers/issues/712
     "--ptxas-options=-O2",
     "--ptxas-options=-allow-expensive-optimizations=true",
 ]
+extra_compile_args["nvcc"] = nvcc_flags
+
+ext_modules += get_flash_attention_extensions(
+    cuda_version=cuda_version, extra_compile_args=extra_compile_args
+)

@danthe3rd (Contributor)

Can we revert these changes? I don't think we need to use O2 for Flash-Attention (and it might also make performance worse).
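
For context, a minimal sketch of the arrangement being asked for here, reusing nvcc_flags, extra_compile_args, and get_flash_attention_extensions from the quoted setup.py excerpt (flash_compile_args is an illustrative name, not from the file): build the Flash-Attention extensions with the plain nvcc flags, and apply the ptxas workaround only to the remaining xformers kernels.

# Sketch only; names other than flash_compile_args follow the setup.py excerpt above.
# Build Flash-Attention with its own copy of the plain nvcc flags ...
flash_compile_args = dict(extra_compile_args, nvcc=list(nvcc_flags))
ext_modules += get_flash_attention_extensions(
    cuda_version=cuda_version, extra_compile_args=flash_compile_args
)

# ... then append the ptxas workaround for nvcc > 11.6 (xformers#712)
# only for the remaining extensions, keeping it away from Flash-Attention
# (see Dao-AILab/flash-attention#359).
extra_compile_args["nvcc"] = nvcc_flags + [
    "--ptxas-options=-O2",
    "--ptxas-options=-allow-expensive-optimizations=true",
]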

@tmm1 (Contributor Author)

Sure, I pulled out that commit. Thanks for the feedback.

@danthe3rd (Contributor) left a comment

Awesome! Thank you for your contribution.

@danthe3rd merged commit eadc8c6 into facebookresearch:main on Aug 11, 2023
1 check passed
@fmassa (Contributor) commented Aug 11, 2023

Looks like this PR gives wrong results for the flashattn backend. We need to investigate, but in the meantime we might revert it.

@tmm1 (Contributor Author) commented Aug 11, 2023

Is there a failing test somewhere, or what exactly is the issue?

@fmassa (Contributor) commented Aug 11, 2023

There are some failing tests (which run on internal infra), but not as many as I originally thought.

Here they are: [screenshot of the failing internal tests]

@danthe3rd (Contributor)

Here is a simple repro (I get the failures on A100 at least):

$ python -m pytest tests/test_mem_eff_attention.py -k "test_backward[flshattBv2-cuda-torch.float16-BlockDiagonalCausalMask"

===================================================================================== short test summary info ======================================================================================
FAILED tests/test_mem_eff_attention.py::test_backward[flshattBv2-cuda-torch.float16-BlockDiagonalCausalMask-1-256-2-1-32-32-False-BMHK] - AssertionError: cutlassF+flshattBv2:query: out=nan and ref=0.0 (diff=nan > 0) at (0, 174, 0, 0) of shape (1, 256, 1, 32) / atol=0.09, rtol=0.02/ total failing elements: 0, percentage=0.0
FAILED tests/test_mem_eff_attention.py::test_backward[flshattBv2-cuda-torch.float16-BlockDiagonalCausalMask-1-256-2-1-32-32-True-BMHK] - AssertionError: cutlassF+flshattBv2:query: out=nan and ref=0.0 (diff=nan > 0) at (0, 174, 0, 0) of shape (1, 256, 1, 32) / atol=0.09, rtol=0.02/ total failing elements: 0, percentage=0.0
FAILED tests/test_mem_eff_attention.py::test_backward[flshattBv2-cuda-torch.float16-BlockDiagonalCausalMask-1-256-15-1-32-32-False-BMHK] - AssertionError: flshattFv2+flshattBv2:query: out=nan and ref=0.0 (diff=nan > 0) at (0, 245, 0, 0) of shape (1, 256, 1, 32) / atol=0.09, rtol=0.02/ total failing elements: 0, percentage=0.0
FAILED tests/test_mem_eff_attention.py::test_backward[flshattBv2-cuda-torch.float16-BlockDiagonalCausalMask-1-256-15-1-32-32-True-BMHK] - AssertionError: flshattFv2+flshattBv2:query: out=nan and ref=0.0 (diff=nan > 0) at (0, 245, 0, 0) of shape (1, 256, 1, 32) / atol=0.09, rtol=0.02/ total failing elements: 0, percentage=0.0
==================================================================== 4 failed, 120 passed, 8 skipped, 14123 deselected in 8.11s ====================================================================

I'll investigate this and send a minimal repro to Tri if it's a bug in Flash-Attention.
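
For reference, a self-contained sketch of what such a minimal repro might look like. The shapes mirror the failing parametrization above (batch 1, total length 256, 1 head, head dim 32, fp16, BMHK layout); the per-sequence split, the use of memory_efficient_attention, and the explicit flash op pair are assumptions rather than the exact internal test.

import torch
import xformers.ops as xops
from xformers.ops import fmha

torch.manual_seed(0)

# Two variable-length sequences packed into one batch-1 BMHK tensor;
# the split is arbitrary, only the total length (256) matches the test.
seqlens = [174, 82]
q = torch.randn(1, sum(seqlens), 1, 32, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q).requires_grad_(True)
v = torch.randn_like(q).requires_grad_(True)

attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens).make_causal()

# Force the FlashAttention v2 forward/backward ops (flshattFv2 / flshattBv2).
out = xops.memory_efficient_attention(
    q, k, v, attn_bias=attn_bias, op=(fmha.flash.FwOp, fmha.flash.BwOp)
)
out.backward(torch.ones_like(out))

assert not torch.isnan(q.grad).any(), "NaN in query gradient"

On an affected build, the assert above should trip on the NaN query gradients reported in the test output.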

@danthe3rd (Contributor)

This is indeed a bug in Flash-Attention - I opened an issue in Dao-AILab/flash-attention#443

@tmm1 (Contributor Author) commented Aug 11, 2023

@danthe3rd in your bug repro it still fails with bfloat16 for me, but the tests for bfloat16 pass here?

pytest tests/test_mem_eff_attention.py -vk "test_backward and flshattBv2-cuda and torch.bfloat16 and BlockDiagonalCausalMask"

@tmm1 (Contributor Author) commented Aug 11, 2023

Tests pass again if we revert 698532d as well.

@danthe3rd (Contributor)

Yes indeed. I was hoping that Tri might have some insight into what is causing this bug, or whether there is a narrower condition. It might also be specific to variable sequence lengths (BlockDiagonalCausalMask).
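
One way to narrow that down (a sketch reusing the tensors and op pair from the repro sketch above) is to run the same shapes with a fixed-length causal mask, which involves no variable sequence lengths:

# Same q/k/v and flash op pair as in the repro sketch, but with a plain
# causal mask instead of BlockDiagonalCausalMask.
q.grad = k.grad = v.grad = None  # reset gradients from the earlier backward
out_fixed = xops.memory_efficient_attention(
    q, k, v, attn_bias=xops.LowerTriangularMask(),
    op=(fmha.flash.FwOp, fmha.flash.BwOp),
)
out_fixed.backward(torch.ones_like(out_fixed))
assert not torch.isnan(q.grad).any()  # expected to pass if the bug is mask-specific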

@jinqiua commented Aug 17, 2023

I tried changing xops.fmha.cutlass.FwOp() to xops.fmha.flash.FwOp(), but I don't see any speedup (I'm using xformers 0.0.21+ba5b449.d20230817). @tmm1
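
For anyone comparing the two backends, a rough timing sketch (the entry point and op classes are the public xformers ones; the shapes and iteration counts here are arbitrary, and whether flash is faster depends heavily on shape, dtype, and GPU):

import time
import torch
import xformers.ops as xops
from xformers.ops import fmha

q = torch.randn(8, 1024, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

def bench(op, iters=50):
    # Warm up, then time the forward pass with the given (FwOp, BwOp) pair.
    for _ in range(5):
        xops.memory_efficient_attention(q, k, v, op=op)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        xops.memory_efficient_attention(q, k, v, op=op)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

print("cutlass:", bench((fmha.cutlass.FwOp, fmha.cutlass.BwOp)))
print("flash:  ", bench((fmha.flash.FwOp, fmha.flash.BwOp)))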

Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Development
Successfully merging this pull request may close these issues:
Rebuild latest wheels on main for FlashAttention 2

7 participants