Skip to content

Commit

Permalink
[Kernel] Add w8a8 CUTLASS kernels (#4749)
Browse files Browse the repository at this point in the history
  • Loading branch information
tlrmchlsmth authored May 16, 2024
1 parent 8435b20 commit 2060e93
Show file tree
Hide file tree
Showing 10 changed files with 1,197 additions and 2 deletions.
27 changes: 26 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,38 @@ set(VLLM_EXT_SRC
"csrc/pybind.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
include(FetchContent)
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
)
FetchContent_MakeAvailable(cutlass)

list(APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu")
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu")

#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")

endif()

define_gpu_extension_target(
Expand All @@ -190,6 +214,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
WITH_SOABI)

#
Expand Down
8 changes: 8 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack(
int64_t size_k,
int64_t size_n,
int64_t num_bits);

int cutlass_scaled_mm_dq(
torch::Tensor& out,
torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor const &a_scales,
torch::Tensor const &b_scales);

#endif

void squeezellm_gemm(
Expand Down
1 change: 1 addition & 0 deletions csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization.");
#endif

ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
Expand Down
12 changes: 12 additions & 0 deletions csrc/quantization/cutlass_w8a8/common.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include "cutlass/cutlass.h"

/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}
Loading

0 comments on commit 2060e93

Please sign in to comment.