[Kernel] Adding bias epilogue support for cutlass_scaled_mm #5560

Merged
22 commits
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)

project(vllm_extensions LANGUAGES CXX)

option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
Collaborator
Could you describe the difference between this and what was there before?

Contributor Author
Accidentally committed; it's required to build locally, and I'm happy to put it in a separate PR. The reason is that `option` only supports booleans; `set(... CACHE STRING ...)` is the equivalent syntax for a string variable.


message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
3 changes: 2 additions & 1 deletion csrc/ops.h
@@ -92,7 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp);

#endif

163 changes: 126 additions & 37 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu
@@ -94,7 +94,7 @@ struct enable_sm89_to_sm90 : Kernel {
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
@@ -134,6 +134,46 @@ struct ScaledEpilogue {
}
};

template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueAzp
: private ScaledEpilogue<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogue<ElementD, OutputTileThreadMap>;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using EVTCompute0 = typename SUPER::EVTCompute0;

using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;

using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;

public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;

static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias_azp) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;

ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{bias_azp.data_ptr<float>(), {}};

typename EVTCompute0::Arguments evt0_compute_args{b_args};

typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
}
};
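For readers unfamiliar with the EVT visitor trees: the new ScaledEpilogueAzp reuses ScaleA, ScaleB, and EVTCompute0 from ScaledEpilogue and fuses a row-broadcast bias through cutlass::multiply_add, so each output element ends up as scale_a * (scale_b * accum) + bias. Below is a minimal host-side sketch of the intended math; the function name and plain-loop form are illustrative only, not the actual CUTLASS visitor code.

```cpp
#include <cstdint>
#include <vector>

// Reference semantics of ScaledEpilogueAzp for one output block, assuming
// per-row a_scales (or a single scalar), per-column b_scales (or a single
// scalar), and a per-column float bias. Mirrors the visitor tree above:
// EVTCompute0 = ScaleB * Accum, Compute1 = multiply_add(ScaleA, EVTCompute0, Bias).
void scaled_mm_bias_reference(const std::vector<int32_t>& acc,     // m*n accumulators
                              const std::vector<float>& a_scales,  // size m or 1
                              const std::vector<float>& b_scales,  // size n or 1
                              const std::vector<float>& bias,      // size n
                              std::vector<float>& out,             // size m*n
                              int m, int n) {
  for (int i = 0; i < m; ++i) {
    float sa = a_scales.size() == 1 ? a_scales[0] : a_scales[i];
    for (int j = 0; j < n; ++j) {
      float sb = b_scales.size() == 1 ? b_scales[0] : b_scales[j];
      // Dequantize-and-bias in one fused step per element.
      out[i * n + j] = sa * (sb * static_cast<float>(acc[i * n + j])) + bias[j];
    }
  }
}
```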

template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
@@ -168,13 +208,13 @@ struct cutlass_2x_gemm {
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
@@ -252,14 +292,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,

} // namespace

void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -268,25 +307,41 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}

void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias_azp) {
TORCH_CHECK(bias_azp->dtype() == torch::kFloat32);
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueAzp>(
out, a, b, a_scales, b_scales, *bias_azp);
} else {
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}

template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);

using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
@@ -295,42 +350,57 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}

void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias_azp) {
TORCH_CHECK(bias_azp->dtype() == torch::kFloat32);
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueAzp>(
out, a, b, a_scales, b_scales, *bias_azp);
} else {
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}

template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);

if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
@@ -339,14 +409,33 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
cutlass::bfloat16_t, Epilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
cutlass::half_t, Epilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b,
std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}

void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias_azp) {
TORCH_CHECK(bias_azp->dtype() == torch::kFloat32);
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueAzp>(
out, a, b, a_scales, b_scales, *bias_azp);
} else {
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
b_scales);
}
}
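As a usage note, the public entry points now take a `c10::optional<torch::Tensor>` and fall back to the original ScaledEpilogue when no bias is supplied. A hedged sketch of a C++ call site follows; the tensor shapes and layouts noted in the comments are assumptions based on the existing kernel, not something this diff specifies.

```cpp
#include <torch/torch.h>

// Declaration as modified in csrc/ops.h by this PR.
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias_azp);

// Hypothetical call sites: a is MxK int8, b is KxN int8, out is MxN fp16/bf16,
// scales and bias are fp32 (shapes are assumptions, not specified by the diff).
void example(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b,
             torch::Tensor const& a_scales, torch::Tensor const& b_scales,
             torch::Tensor const& bias) {
  // No bias: same behavior as before this PR (ScaledEpilogue path).
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, c10::nullopt);

  // With a per-column fp32 bias: selects the ScaledEpilogueAzp path on SM75-SM89.
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, bias);
}
```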
5 changes: 4 additions & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -319,10 +319,13 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

TORCH_CHECK(!bias_azp, "Bias AZP not supported on Hopper yet");

if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);

27 changes: 16 additions & 11 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -6,28 +6,33 @@
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp);
#endif

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias_azp) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
@@ -57,19 +62,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,

// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias_azp);
#else
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias_azp);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias_azp);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias_azp);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias_azp);
}
}
}
3 changes: 2 additions & 1 deletion csrc/torch_bindings.cpp
@@ -138,8 +138,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"cutlass_scaled_mm(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
" Tensor b_scales, Tensor? bias_azp) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

#endif

// Quantized GEMM for GPTQ.
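One note on the schema change in torch_bindings.cpp: `Tensor? bias_azp` marks the new argument as optional, so Python callers can pass `None` and the C++ side receives `c10::nullopt`, which preserves the previous behavior. Below is a hedged sketch of the kind of input validation a wrapper might add around the new argument; it is a hypothetical helper, not part of this PR, and the shape and contiguity expectations are inferred from the row-broadcast bias in ScaledEpilogueAzp rather than stated in the diff.

```cpp
#include <torch/torch.h>

// "Tensor? bias_azp" in the op schema makes the argument optional at the
// dispatcher level; it arrives in C++ as c10::optional<torch::Tensor>.
// Hypothetical validation helper (not part of this PR).
void check_optional_bias(c10::optional<torch::Tensor> const& bias_azp,
                         torch::Tensor const& out) {
  if (!bias_azp) {
    return;  // Omitted bias keeps the original ScaledEpilogue behavior.
  }
  // Inferred expectations: an fp32 row vector with one entry per output column.
  TORCH_CHECK(bias_azp->dtype() == torch::kFloat32);
  TORCH_CHECK(bias_azp->dim() == 1 && bias_azp->numel() == out.size(1));
  TORCH_CHECK(bias_azp->is_contiguous());
}
```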