Use cutlass for memory-efficient attention #362
Changes from 56 commits
Coverage configuration:

@@ -6,3 +6,4 @@ omit =
     xformers/benchmarks/*
     xformers/triton/k_*
     stubs/*
+    third_party/*
Git ignore rules:

@@ -52,3 +52,5 @@ examples/data
 # Hydra default output dir
 multirun
 outputs
+
+.benchmarks
Git submodules (`.gitmodules`, new file):

@@ -0,0 +1,7 @@
+[submodule "third_party/flash-attention"]
+	path = third_party/flash-attention
+	url = https://github.com/HazyResearch/flash-attention.git
+[submodule "third_party/cutlass"]
+	path = third_party/cutlass
+	url = https://github.com/fmassa/cutlass.git
+	branch = updates_for_mha
isort configuration:

@@ -1,2 +1,3 @@
 [settings]
 known_third_party =fvcore,hydra,input_pipeline,matplotlib,numpy,omegaconf,pandas,pl_bolts,pyre_extensions,pytest,pytorch_lightning,ragged_inference,recommonmark,seaborn,setuptools,sklearn,submitit,tensorflow,timm,torch,torchmetrics,torchvision,tqdm,triton,typing_extensions
+skip_glob=third_party/*
Installation instructions:

@@ -53,6 +53,9 @@ There are two ways you can install xFormers locally:

```bash
+git clone git@github.com:facebookresearch/xformers.git
+git submodule update --init --recursive
 conda create --name xformer_env python=3.8
 conda activate xformer_env
+cd xformers
 pip install -r requirements.txt
 pip install -e .
```

Review comment (on the `git submodule update --init --recursive` line): nice, I was checking that when I saw that xformers now has two submodules, perfect. Thanks.
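Once the steps above finish, a quick sanity check (a sketch, not part of the PR; it assumes the package exposes `__version__`, which the `find_version` helper in `setup.py` below suggests) is to import the package:

```python
# Sketch: verify that the editable install of xformers is importable.
import xformers

print(xformers.__version__)
```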
`setup.py`:

@@ -10,7 +10,9 @@
 import os
 import re
 import shutil
+import subprocess
 import sys
+from pathlib import Path

 import setuptools
 import torch
@@ -44,6 +46,84 @@ def find_version(version_file_path):
     raise RuntimeError("Unable to find version string.")


+def get_cuda_version(cuda_dir) -> int:
+    nvcc_bin = "nvcc" if cuda_dir is None else cuda_dir + "/bin/nvcc"
+    raw_output = subprocess.check_output([nvcc_bin, "-V"], universal_newlines=True)
+    output = raw_output.split()
+    release_idx = output.index("release") + 1
+    release = output[release_idx].split(".")
+    bare_metal_major = int(release[0])
+    bare_metal_minor = int(release[1][0])
+
+    assert bare_metal_minor < 100
+    return bare_metal_major * 100 + bare_metal_minor
+
+
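As a side note (not part of the diff), here is how the encoding in `get_cuda_version` plays out on a representative `nvcc -V` release line; the sample string below is illustrative, not taken from this PR:

```python
# Sketch only: mirrors the parsing done by get_cuda_version above.
sample = "Cuda compilation tools, release 11.3, V11.3.109"
tokens = sample.split()
release = tokens[tokens.index("release") + 1].split(".")  # ["11", "3,"]
major, minor = int(release[0]), int(release[1][0])        # 11, 3
print(major * 100 + minor)  # 1103, the integer later compared against 1100/1102
```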
+def get_flash_attention_extensions(cuda_version: int, extra_compile_args):
+    # Figure out default archs to target
+    DEFAULT_ARCHS_LIST = ""
+    if cuda_version > 1100:
+        DEFAULT_ARCHS_LIST = "7.5;8.0;8.6"
+    elif cuda_version >= 1100:
+        DEFAULT_ARCHS_LIST = "7.5;8.0"
+    else:
+        return []

Review comment (on the `elif cuda_version >= 1100:` branch): nit, but `cuda_version == 1100` in that case, right?

+
+    archs_list = os.environ.get("TORCH_CUDA_ARCH_LIST", DEFAULT_ARCHS_LIST)
+    nvcc_archs_flags = []
+    for arch in archs_list.split(";"):
+        assert len(arch) >= 3, f"Invalid sm version: {arch}"
+
+        num = 10 * int(arch[0]) + int(arch[2])
+        # Need at least 7.0
+        if num < 75:
+            continue

Review comment (on the `if num < 75: continue` check): could we print out some warnings here (or in the main setup) to recap what's being built and possibly why? I feel like there could be a lot of issues raised around this, with the build process silently skipping flash-attention because of an old CUDA version and users not seeing it.

Reply: Good idea, I'll add some log messages. But in general, we need to improve the packaging of xformers, especially now that a lot of hardware-specific kernels are being used. @bottler might look into improving this.

+        nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
+        if arch.endswith("+PTX"):
+            nvcc_archs_flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
+    if not nvcc_archs_flags:
+        return []
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    flash_root = os.path.join(this_dir, "third_party", "flash-attention")
+    return [
+        CUDAExtension(
+            name="xformers._C_flashattention",
+            sources=[
+                os.path.join(this_dir, "third_party", "flash-attention", path)
+                for path in [
+                    "csrc/flash_attn/fmha_api.cpp",
+                    "csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu",
+                    "csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu",
+                    "csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu",
+                    "csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu",
+                ]
+            ],
+            extra_compile_args={
+                **extra_compile_args,
+                "nvcc": extra_compile_args.get("nvcc", [])
+                + [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                    "--use_fast_math",
+                    "--ptxas-options=-v",
+                    "-lineinfo",
+                ]
+                + nvcc_archs_flags,
+            },
+            include_dirs=[
+                Path(flash_root) / "csrc" / "flash_attn",
+                Path(flash_root) / "csrc" / "flash_attn" / "src",
+                # Path(flash_root) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
+                Path(this_dir) / "third_party" / "cutlass" / "include",
+            ],
+        )
+    ]
+
+
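As an aside (not part of the diff), the arch handling above can be exercised in isolation. The helper below is a hypothetical standalone sketch that mirrors the parsing in `get_flash_attention_extensions`: entries from `TORCH_CUDA_ARCH_LIST` below sm_75 are skipped, and `+PTX` entries additionally emit a `compute_XX` target.

```python
# Sketch: same arch -> nvcc flag mapping as get_flash_attention_extensions above.
def archs_to_nvcc_flags(archs_list: str) -> list:
    flags = []
    for arch in archs_list.split(";"):
        assert len(arch) >= 3, f"Invalid sm version: {arch}"
        num = 10 * int(arch[0]) + int(arch[2])  # "8.6+PTX" -> 86
        if num < 75:  # the build above skips anything below sm_75
            continue
        flags.append(f"-gencode=arch=compute_{num},code=sm_{num}")
        if arch.endswith("+PTX"):
            flags.append(f"-gencode=arch=compute_{num},code=compute_{num}")
    return flags

# "7.0" is dropped, "8.6+PTX" emits both an sm_86 and a compute_86 target.
print(archs_to_nvcc_flags("7.0;7.5;8.6+PTX"))
```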
 def get_extensions():
     this_dir = os.path.dirname(os.path.abspath(__file__))
     extensions_dir = os.path.join(

@@ -57,9 +137,11 @@ def get_extensions():
     )

     sources = main_file + source_cpu

     source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))

     sputnik_dir = os.path.join(this_dir, "third_party", "sputnik")
+    cutlass_dir = os.path.join(this_dir, "third_party", "cutlass", "include")

     extension = CppExtension

@@ -73,31 +155,42 @@ def get_extensions():
         extra_compile_args["cxx"].append("-fopenmp")

     include_dirs = [extensions_dir]
+    ext_modules = []

     if (torch.cuda.is_available() and ((CUDA_HOME is not None))) or os.getenv(
         "FORCE_CUDA", "0"
     ) == "1":
         extension = CUDAExtension
         sources += source_cuda
-        include_dirs += [sputnik_dir]
+        include_dirs += [sputnik_dir, cutlass_dir]
         nvcc_flags = os.getenv("NVCC_FLAGS", "")
         if nvcc_flags == "":
             nvcc_flags = []
         else:
             nvcc_flags = nvcc_flags.split(" ")
+        cuda_version = get_cuda_version(CUDA_HOME)
+        if cuda_version >= 1102:
+            nvcc_flags += ["--threads", "4", "--ptxas-options=-v"]
         extra_compile_args["nvcc"] = nvcc_flags
+        if (
+            cuda_version >= 1100
+            and os.getenv("XFORMERS_DISABLE_FLASH_ATTN", "0") == "0"
+        ):
+            ext_modules += get_flash_attention_extensions(
+                cuda_version=cuda_version, extra_compile_args=extra_compile_args
+            )

     sources = [os.path.join(extensions_dir, s) for s in sources]

-    ext_modules = [
+    ext_modules.append(
         extension(
             "xformers._C",
             sorted(sources),
             include_dirs=include_dirs,
             define_macros=define_macros,
             extra_compile_args=extra_compile_args,
         )
-    ]
+    )

     return ext_modules

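Once built, a simple way to see which native extensions were actually compiled is to try importing them. This is a post-install sketch, not part of the PR; the module names come from the `setup.py` changes above.

```python
# Sketch: report which xformers native extensions are importable after the build.
import importlib


def has_module(name: str) -> bool:
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        return False


# xformers._C_flashattention is only built when a CUDA toolkit >= 11.0 is found
# at build time and XFORMERS_DISABLE_FLASH_ATTN is not set to "1".
print("xformers._C:", has_module("xformers._C"))
print("xformers._C_flashattention:", has_module("xformers._C_flashattention"))
```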
Review comment: thanks for all the CI changes, LGTM and pretty useful.

Review comment: hi @lucidrains, do you run into this kind of precision loss during your training? I see it when training SWIN-T.