Skip to content

Commit

Permalink
[Kernel] Build flash-attn from source (vllm-project#8245)
Browse files Browse the repository at this point in the history
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
  • Loading branch information
ProExpertProg authored and sumitd2 committed Nov 14, 2024
1 parent e78b351 commit 9d244e4
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 41 deletions.
1 change: 1 addition & 0 deletions .github/workflows/scripts/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,6 @@ $python_executable -m pip install -r requirements-cuda.txt
export MAX_JOBS=1
# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real"
# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# vllm commit id, generated by setup.py
vllm/commit_id.py

# vllm-flash-attn built from source
vllm/vllm_flash_attn/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand All @@ -12,6 +15,8 @@ __pycache__/
# Distribution / packaging
.Python
build/
cmake-build-*/
CMakeUserPresets.json
develop-eggs/
dist/
downloads/
Expand Down
98 changes: 73 additions & 25 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
cmake_minimum_required(VERSION 3.26)

# When building directly using CMake, make sure you run the install step
# (it places the .so files in the correct location).
#
# Example:
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
# cmake --build . --target install
#
# If you want to only build one target, make sure to install it manually:
# cmake --build . --target _C
# cmake --install . --component _C
project(vllm_extensions LANGUAGES CXX)

# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
Expand All @@ -13,6 +24,9 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")

# Prevent installation of dependencies (cutlass) by default.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
Expand Down Expand Up @@ -70,19 +84,6 @@ endif()
find_package(Torch REQUIRED)

#
# Add the `default` target which detects which extensions should be
# built based on platform/architecture. This is the same logic that
# setup.py uses to select which extensions should be built and should
# be kept in sync.
#
# The `default` target makes direct use of cmake easier since knowledge
# of which extensions are supported has been factored in, e.g.
#
# mkdir build && cd build
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
# cmake --build . --target default
#
add_custom_target(default)
message(STATUS "Enabling core extension.")

# Define _core_C extension
Expand All @@ -100,8 +101,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)

add_dependencies(default _core_C)

#
# Forward the non-CUDA device extensions to external CMake scripts.
#
Expand Down Expand Up @@ -167,6 +166,8 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif()

include(FetchContent)

#
# Define other extension targets
#
Expand All @@ -190,7 +191,6 @@ set(VLLM_EXT_SRC
"csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
FetchContent_Declare(
cutlass
Expand Down Expand Up @@ -283,6 +283,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
csrc/quantization/machete/machete_pytorch.cu)
endif()

message(STATUS "Enabling C extension.")
define_gpu_extension_target(
_C
DESTINATION vllm
Expand Down Expand Up @@ -313,6 +314,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/moe/marlin_moe_ops.cu")
endif()

message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
_moe_C
DESTINATION vllm
Expand All @@ -323,7 +325,6 @@ define_gpu_extension_target(
USE_SABI 3
WITH_SOABI)


if(VLLM_GPU_LANG STREQUAL "HIP")
#
# _rocm_C extension
Expand All @@ -343,16 +344,63 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()

# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda")
return()
endif ()

if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling C extension.")
add_dependencies(default _C)
#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component vllm_flash_attn_c.
# If no component is specified, vllm-flash-attn is still installed.

message(STATUS "Enabling moe extension.")
add_dependencies(default _moe_C)
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_GPU_LANG STREQUAL "HIP")
message(STATUS "Enabling rocm extension.")
add_dependencies(default _rocm_C)
if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR})
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
GIT_PROGRESS TRUE
)
endif()

# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization.
set(VLLM_PARENT_BUILD ON)

# Make sure vllm-flash-attn install rules are nested under vllm/
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c)

# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Restore the install prefix
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c)

# Copy over the vllm-flash-attn python files
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT vllm_flash_attn_c
FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
3 changes: 3 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# Override the arch list for flash-attn to reduce the binary size
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
#################### BASE BUILD IMAGE ####################

#################### WHEEL BUILD IMAGE ####################
Expand Down
2 changes: 1 addition & 1 deletion cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -364,5 +364,5 @@ function (define_gpu_extension_target GPU_MOD_NAME)
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()

install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()
1 change: 0 additions & 1 deletion requirements-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ torch == 2.4.0
# These must be updated alongside torch
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
vllm-flash-attn == 2.6.1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
38 changes: 30 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import subprocess
import sys
import warnings
from pathlib import Path
from shutil import which
from typing import Dict, List

Expand Down Expand Up @@ -152,15 +153,8 @@ def configure(self, ext: CMakeExtension) -> None:
default_cfg = "Debug" if self.debug else "RelWithDebInfo"
cfg = envs.CMAKE_BUILD_TYPE or default_cfg

# where .so files will be written, should be the same for all extensions
# that use the same CMakeLists.txt.
outdir = os.path.abspath(
os.path.dirname(self.get_ext_fullpath(ext.name)))

cmake_args = [
'-DCMAKE_BUILD_TYPE={}'.format(cfg),
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}'.format(outdir),
'-DCMAKE_ARCHIVE_OUTPUT_DIRECTORY={}'.format(self.build_temp),
'-DVLLM_TARGET_DEVICE={}'.format(VLLM_TARGET_DEVICE),
]

Expand Down Expand Up @@ -224,10 +218,12 @@ def build_extensions(self) -> None:
os.makedirs(self.build_temp)

targets = []
target_name = lambda s: remove_prefix(remove_prefix(s, "vllm."),
"vllm_flash_attn.")
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(remove_prefix(ext.name, "vllm."))
targets.append(target_name(ext.name))

num_jobs, _ = self.compute_num_jobs()

Expand All @@ -240,6 +236,28 @@ def build_extensions(self) -> None:

subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)

# Install the libraries
for ext in self.extensions:
# Install the extension into the proper location
outdir = Path(self.get_ext_fullpath(ext.name)).parent.absolute()

# Skip if the install directory is the same as the build directory
if outdir == self.build_temp:
continue

# CMake appends the extension prefix to the install path,
# and outdir already contains that prefix, so we need to remove it.
prefix = outdir
for i in range(ext.name.count('.')):
prefix = prefix.parent

# prefix here should actually be the same for all components
install_args = [
"cmake", "--install", ".", "--prefix", prefix, "--component",
target_name(ext.name)
]
subprocess.check_call(install_args, cwd=self.build_temp)


def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
Expand Down Expand Up @@ -467,6 +485,10 @@ def _read_requirements(filename: str) -> List[str]:
if _is_hip():
ext_modules.append(CMakeExtension(name="vllm._rocm_C"))

if _is_cuda():
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn.vllm_flash_attn_c"))

if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C"))

Expand Down
9 changes: 7 additions & 2 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
# yapf: disable
from vllm.vllm_flash_attn import (
flash_attn_varlen_func as _flash_attn_varlen_func)
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as _flash_attn_with_kvcache)

# yapf: enable


@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
Expand Down
8 changes: 4 additions & 4 deletions vllm/attention/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,7 @@ def which_attn_to_use(
# FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN:
try:
import vllm_flash_attn # noqa: F401

import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)

Expand All @@ -258,8 +257,9 @@ def which_attn_to_use(
except ImportError:
logger.info(
"Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance.")
"vllm.vllm_flash_attn package is not found. "
"Make sure that vllm_flash_attn was built and installed "
"(on by default).")
selected_backend = _Backend.XFORMERS

return selected_backend
Expand Down

0 comments on commit 9d244e4

Please sign in to comment.