[Bug]: Perf slump after updating to flash-attn 2.7.0 (when using torch.compile) #1341

Mnb66 opened this issue Nov 16, 2024 · 1 comment

Mnb66 commented Nov 16, 2024

  • I was training a Mosaic BERT model with MosaicML Composer, torch.compile, DeepSpeed, and flash-attn 2.

  • For flash-attn 2.6.3, when used with torch.compile, the following warning was raised: lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph. (A sketch of the two workarounds the warning suggests appears after this list.)

  • Updating to the newest flash-attn 2.7.0.post2 resolved the warning and improved training speed by ~28% compared to version 2.6.3. However, with flash-attn 2.7.0 the training loss got significantly worse (see the loss curve below).

[Image: training loss curves, flash-attn 2.6.3 vs 2.7.0.post2]

  • I'm using torch 2.5.1, CUDA 12.4, DeepSpeed 0.15.4, and Composer 0.27.0.

  • Any advice would be greatly appreciated.
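
As context for the workarounds named in the warning above, here is a minimal, hypothetical sketch (not part of the original report) of keeping the flash-attn call inside the compiled graph. The op name `mylib::fa2_fwd` is illustrative, and flash-attn ≥ 2.7.0 appears to ship its own custom-op registration, which is presumably why the warning disappears after upgrading.

```python
# Hypothetical sketch of the warning's two suggested workarounds, aimed at
# flash-attn 2.6.x (flash-attn >= 2.7.0 appears to handle this itself).
import torch
from flash_attn import flash_attn_func

# Option 1: tell Dynamo the call is traceable and should stay in the graph.
# This only helps if the extension tolerates Dynamo's tracing machinery.
torch.compiler.allow_in_graph(flash_attn_func)

# Option 2: wrap the kernel as a PyTorch custom op so torch.compile treats it
# as an opaque node with a known output shape. "mylib::fa2_fwd" is illustrative.
@torch.library.custom_op("mylib::fa2_fwd", mutates_args=())
def fa2_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            causal: bool) -> torch.Tensor:
    return flash_attn_func(q, k, v, causal=causal)

@fa2_fwd.register_fake
def _(q, k, v, causal):
    # flash-attn's output has the same shape and dtype as q.
    return torch.empty_like(q)

# For training, a real wrapper would also need a backward registered via
# fa2_fwd.register_autograd(...); omitted here for brevity.
```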

tridao commented Nov 16, 2024

To isolate the issue, can you run flash-attn 2.7.0.post2 but without torch.compile? If the quality is fine in eager mode, then the issue is in the way torch.compile interacts with flash-attn.
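
One way to do that A/B check at the kernel level is sketched below. This is only an illustration of the suggestion, not the exact experiment (the suggestion is to rerun the full training job in eager mode); the shapes and dtype are made up.

```python
# Sketch: compare the same flash-attn call in eager mode vs. under
# torch.compile with flash-attn 2.7.0.post2 installed. Shapes are illustrative.
import torch
from flash_attn import flash_attn_func

def attn(q, k, v):
    return flash_attn_func(q, k, v, causal=False)

compiled_attn = torch.compile(attn)

q, k, v = (torch.randn(2, 512, 8, 64, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

out_eager = attn(q, k, v)
out_compiled = compiled_attn(q, k, v)
print("max abs diff (eager vs compiled):",
      (out_eager - out_compiled).abs().max().item())
```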

If the issue is with the kernel itself, can you help us with a script to reproduce it: construct some specific tensors (e.g. loaded from disk), run the attention forward and backward, and show that the output and gradients are drastically different from a reference implementation (standard attention) in fp32.
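
A hedged sketch of that kind of repro script follows (not from the thread): it runs flash-attn forward and backward on fixed tensors and compares outputs and gradients against standard attention computed in fp32. The shapes, seed, and the qkv_repro.pt filename are illustrative, and a bf16 kernel will not match fp32 exactly, so the question is whether the differences are drastic rather than small.

```python
import math
import torch
from flash_attn import flash_attn_func

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 512, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                       dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))
# Or load the exact tensors that show the problem (hypothetical filename):
# q, k, v = torch.load("qkv_repro.pt")

# flash-attn forward and backward.
out = flash_attn_func(q, k, v, causal=False)
grad_out = torch.randn_like(out)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), grad_out)

# Reference: standard attention in fp32 (layout b, s, h, d -> b, h, s, d).
def ref_attention(q, k, v):
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return (scores.softmax(dim=-1) @ v).transpose(1, 2)

q32, k32, v32 = (t.detach().float().requires_grad_() for t in (q, k, v))
out_ref = ref_attention(q32, k32, v32)
dq_r, dk_r, dv_r = torch.autograd.grad(out_ref, (q32, k32, v32),
                                       grad_out.float())

# Report the largest absolute deviation for the output and each gradient.
for name, a, b in [("out", out, out_ref), ("dq", dq, dq_r),
                   ("dk", dk, dk_r), ("dv", dv, dv_r)]:
    print(name, (a.float() - b).abs().max().item())
```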
