I was training a Mosaic BERT model with mosaicml composer, torch.compile, deepspeed and flash-attn 2.

With flash-attn 2.6.3, when used together with torch.compile, the following warning was raised:

lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:725: UserWarning: Graph break due to unsupported builtin flash_attn_2_cuda.PyCapsule.fwd. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
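For reference, the allow_in_graph workaround the warning suggests would look roughly like the sketch below. This is a minimal, untested sketch (not something I actually ran in my setup); it assumes the usual flash_attn_func wrapper from the flash_attn package is the call that dynamo breaks on:

```python
import torch
from flash_attn import flash_attn_func  # public flash-attn Python API

# Ask dynamo to keep flash_attn_func in the compiled graph as an opaque call
# instead of trying to trace into the C extension (which causes the graph break).
torch.compiler.allow_in_graph(flash_attn_func)

@torch.compile
def attn(q, k, v):
    # q, k, v: (batch, seqlen, nheads, headdim), fp16/bf16 tensors on CUDA
    return flash_attn_func(q, k, v, causal=False)
```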
Updating to the newest flash-attn 2.7.0.post2 resolved the warning and improved training speed by ~28% compared to 2.6.3. However, with flash-attn 2.7.0 the loss got significantly worse (see the loss curve below).
I'm using torch 2.5.1, CUDA 12.4, deepspeed 0.15.4 and composer 0.27.0.
Any advice will be greatly appreciated.
To isolate the issue, can you run flash-attn 2.7.0.post2 but without torch.compile? If the quality is fine in eager mode, then the issue is in the way torch.compile interacts with flash-attn.
If the issue is with the kernel itself, can you help us with a script to reproduce it: construct some specific tensors (e.g. loaded from disk), run the attention forward and backward, and show that the output and gradients are drastically different from a reference implementation (standard attention) in fp32.
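Something along these lines would be enough as a starting point (a rough sketch; random inputs are just a placeholder, loading the exact tensors that trigger the problem from disk would be even better):

```python
import torch
from flash_attn import flash_attn_func

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 1024, 16, 64

# Inputs in bf16 (replace with tensors loaded from disk if random ones don't reproduce it)
q, k, v = [torch.randn(batch, seqlen, nheads, headdim, device="cuda",
                       dtype=torch.bfloat16, requires_grad=True) for _ in range(3)]
dout = torch.randn_like(q)

# flash-attn forward + backward
out = flash_attn_func(q, k, v, causal=False)
out.backward(dout)
dq, dk, dv = q.grad.clone(), k.grad.clone(), v.grad.clone()

# Reference: standard attention in fp32
q32, k32, v32 = [t.detach().float().requires_grad_() for t in (q, k, v)]
scores = torch.einsum("bshd,bthd->bhst", q32, k32) / headdim ** 0.5
probs = scores.softmax(dim=-1)
out_ref = torch.einsum("bhst,bthd->bshd", probs, v32)
out_ref.backward(dout.float())

# Max absolute difference of outputs and gradients vs the fp32 reference
for name, a, b in [("out", out, out_ref), ("dq", dq, q32.grad),
                   ("dk", dk, k32.grad), ("dv", dv, v32.grad)]:
    print(name, (a.float() - b).abs().max().item())
```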