[Kernel] Adding bias epilogue support for cutlass_scaled_mm #5560

Merged

CMakeLists.txt (3 changes: 2 additions & 1 deletion)

@@ -2,7 +2,8 @@ cmake_minimum_required(VERSION 3.21)

 project(vllm_extensions LANGUAGES CXX)

-option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "cuda")
+# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
+set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")

Review comment (Collaborator):
Could you describe the difference between this and what was there before?

Reply (Contributor, Author):
Accidentally committed; it's required to build locally, and I'm happy to put it in a separate PR. The reason is that option only supports booleans; this is the equivalent syntax for a string variable.


message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
Expand Down

csrc/ops.h (3 changes: 2 additions & 1 deletion)

@@ -92,7 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,

 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
-                       torch::Tensor const& b_scales);
+                       torch::Tensor const& b_scales,
+                       c10::optional<torch::Tensor> const& bias);

 #endif
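
The interface change is confined to the new trailing parameter. A hypothetical caller sketch (illustration only, not part of this PR; assumes the declaration above is in scope and all tensors are CUDA tensors):

// Hypothetical usage of the cutlass_scaled_mm declaration above.
void scaled_mm_examples(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
                        torch::Tensor const& bias) {
  // Without bias: c10::nullopt yields an empty optional, which selects the
  // original ScaledEpilogue path.
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, c10::nullopt);
  // With bias: the tensor converts implicitly to c10::optional<torch::Tensor>,
  // which selects the new ScaledEpilogueBias path.
  cutlass_scaled_mm(out, a, b, a_scales, b_scales, bias);
}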

csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu (224 changes: 167 additions & 57 deletions)

@@ -77,31 +77,45 @@ struct enable_sm89_to_sm90 : Kernel {
 };

 /*
-   This epilogue function defines a quantized GEMM operation similar to
-   torch._scaled_mm.
-
-   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
-   per-row. B can be quantized per-tensor or per-column.
-   Any combination of per-tensor and per-row or column is supported.
-   A and B must have symmetric quantization (zero point == 0).
-
-   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
-   scales are applied elementwise with numpy-style broadcasting.
-
-   ScaleA and ScaleB define the epilogue functions that apply the scales for
-   the A and B operands respectively. These scales may be either per-tensor or
-   per row or column.
-*/
+ * This class provides the common ScaleA and ScaleB descriptors for the
+ * ScaledEpilogue and ScaledEpilogueBias classes.
+ */
 template <typename ElementD, typename OutputTileThreadMap>
-struct ScaledEpilogue {
- private:
+struct ScaledEpilogueBase {
+ protected:
   using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

   using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;

   using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
       OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+};
+
+/*
+   This epilogue function defines a quantized GEMM operation similar to
+   torch._scaled_mm.
+
+   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
+   per-row. B can be quantized per-tensor or per-column.
+   Any combination of per-tensor and per-row or column is supported.
+   A and B must have symmetric quantization (zero point == 0).
+
+   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
+   scales are applied elementwise with numpy-style broadcasting.
+
+   ScaleA and ScaleB define the epilogue functions that apply the scales for
+   the A and B operands respectively. These scales may be either per-tensor or
+   per row or column.
+*/
+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogue
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;

   using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
       cutlass::multiplies, float, float,
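
For intuition, ScaledEpilogue's visitor tree (its tail is cut off by the hunk above) evaluates the comment's formula per output element: the inner node scales the accumulator by the B scale and the outer node applies the A scale, with the Visitor*OrScalarBroadcast nodes handling per-tensor versus per-row/column indexing. A scalar reference, as an illustrative sketch only:

// Illustrative scalar reference for ScaledEpilogue (not code from this PR):
// D[i][j] = a_scale[i] * (b_scale[j] * accum[i][j]), computed in float and
// converted to ElementD at the end. A per-tensor scale is the same value
// for every i (or j).
float scaled_epilogue_ref(float accum, float a_scale_i, float b_scale_j) {
  float tmp = b_scale_j * accum;  // inner node: cutlass::multiplies
  return a_scale_i * tmp;         // outer node: multiply, then convert
}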

@@ -134,6 +148,53 @@
   }
 };

+template <typename ElementD, typename OutputTileThreadMap>
+struct ScaledEpilogueBias
+    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
+ private:
+  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
+  using Accum = typename SUPER::Accum;
+  using ScaleA = typename SUPER::ScaleA;
+  using ScaleB = typename SUPER::ScaleB;
+
+  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiplies, float, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using EVTCompute0 =
+      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
+
+  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
+      cutlass::multiply_add, ElementD, float,
+      cutlass::FloatRoundStyle::round_to_nearest>;
+
+  using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
+      OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
+
+ public:
+  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
+                                                             EVTCompute0, Bias>;
+  using ArgumentType = typename EVTCompute::Arguments;
+
+  static ArgumentType prepare_args(torch::Tensor const& a_scales,
+                                   torch::Tensor const& b_scales,
+                                   torch::Tensor const& bias) {
+    using ScaleAArgs = typename ScaleA::Arguments;
+    using ScaleBArgs = typename ScaleB::Arguments;
+    using BiasArgs = typename Bias::Arguments;
+
+    ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
+    ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
+    BiasArgs bias_args{bias.data_ptr<float>(), {}};
+
+    typename EVTCompute0::Arguments evt0_compute_args{b_args};
+
+    typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
+                                                    bias_args};
+    return evt_compute_args;
+  }
+};
+
 template <typename Arch, template <typename> typename ArchGuard,
           typename ElementAB_, typename ElementD_,
           template <typename, typename> typename Epilogue_, typename TileShape,
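
ScaledEpilogueBias reuses the same EVTCompute0 subtree but swaps the outer node for a cutlass::multiply_add, fusing the bias into the same pass. Note that Bias uses VisitorRowBroadcast rather than VisitorRowOrScalarBroadcast, so unlike the scales it is always a full row (one float per output column) and never a scalar; accordingly, BiasArgs carries no scalar flag. Per output element, as an illustrative scalar sketch only:

// Illustrative scalar reference for ScaledEpilogueBias (not code from this
// PR): D[i][j] = a_scale[i] * (b_scale[j] * accum[i][j]) + bias[j].
float scaled_epilogue_bias_ref(float accum, float a_scale_i, float b_scale_j,
                               float bias_j) {
  float tmp = b_scale_j * accum;    // EVTCompute0: cutlass::multiplies
  return a_scale_i * tmp + bias_j;  // Compute1: cutlass::multiply_add
}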

@@ -168,13 +229,13 @@ struct cutlass_2x_gemm {
   // clang-format off
   using RowMajor = typename cutlass::layout::RowMajor;
   using ColumnMajor = typename cutlass::layout::ColumnMajor;
-  using KernelType = 
+  using KernelType =
     ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
-      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16, 
-      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16, 
+      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
+      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
       float, cutlass::layout::RowMajor, 4,
-      ElementAcc, float, cutlass::arch::OpClassTensorOp, 
-      Arch, 
+      ElementAcc, float, cutlass::arch::OpClassTensorOp,
+      Arch,
       TileShape, WarpShape, InstructionShape,
       EVTD,
       cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,

@@ -252,14 +313,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,

 }  // namespace

-void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
-                            torch::Tensor const& b,
-                            torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   TORCH_CHECK(a.dtype() == torch::kInt8);
   TORCH_CHECK(b.dtype() == torch::kInt8);
-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
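
This is the pattern the PR applies to all three entry points: a templated *_epilogue worker keeps the dtype and tile-shape dispatch in one place and perfect-forwards an opaque argument pack, while a thin public wrapper (below) picks ScaledEpilogue or ScaledEpilogueBias depending on whether a bias was supplied. Reduced to a standalone sketch with hypothetical names (simplified from the template-template form used above):

#include <utility>

// Simplified illustration of the dispatch pattern; Epilogue stands in for
// ScaledEpilogue / ScaledEpilogueBias, whose static prepare_args hook
// consumes whatever arguments the wrapper forwarded.
template <typename Epilogue, typename... EpilogueArgs>
void run_with_epilogue(EpilogueArgs&&... epilogue_args) {
  auto args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_args)...);
  (void)args;  // the real cutlass_gemm_caller feeds these into the kernel
}

Because the pack is forwarded untouched all the way to prepare_args, adding an epilogue with a different argument list requires no change to the dispatch layers.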

@@ -268,25 +328,41 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }

-void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(a.dtype() == torch::kInt8);
-  TORCH_CHECK(b.dtype() == torch::kInt8);
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == torch::kFloat32);
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);

   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;

@@ -295,58 +371,92 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
   if (out.dtype() == torch::kBFloat16) {
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   } else {
     TORCH_CHECK(out.dtype() == torch::kFloat16);
     return cutlass_gemm_caller<cutlass_2x_gemm<
         cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
-        ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-        out, a, b, a_scales, b_scales);
+        Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
   }
 }

-void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
-                            torch::Tensor const& b_scales) {
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == torch::kFloat32);
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
+
+template <template <typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
+                                     torch::Tensor const& b,
+                                     EpilogueArgs&&... epilogue_args) {
   using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
   using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
   using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

-  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
-  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
-
   if (a.dtype() == torch::kInt8) {
     TORCH_CHECK(b.dtype() == torch::kInt8);

     if (out.dtype() == torch::kBFloat16) {
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       assert(out.dtype() == torch::kFloat16);
       return cutlass_gemm_caller<cutlass_2x_gemm<
           cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
-          ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
-          out, a, b, a_scales, b_scales);
+          Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   } else {
     TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
     TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

     if (out.dtype() == torch::kBFloat16) {
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     } else {
       TORCH_CHECK(out.dtype() == torch::kFloat16);
-      return cutlass_gemm_caller<cutlass_2x_gemm<
-          cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
-          cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
-          InstructionShape, 5>>(out, a, b, a_scales, b_scales);
+      return cutlass_gemm_caller<
+          cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
+                          cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
+                          TileShape, WarpShape, InstructionShape, 5>>(
+          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
     }
   }
 }
+
+void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            c10::optional<torch::Tensor> const& bias) {
+  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
+  if (bias) {
+    TORCH_CHECK(bias->dtype() == torch::kFloat32);
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
+        out, a, b, a_scales, b_scales, *bias);
+  } else {
+    return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
+                                                           b_scales);
+  }
+}
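
Taken together, the three wrappers define the bias contract: the bias is optional; when present it must be float32 and, per the row-broadcast stride, supplies one value per output column. A hypothetical libtorch shape sketch (illustration only; real inputs must be CUDA tensors with B laid out column-major, both omitted here):

#include <torch/torch.h>

// Hypothetical shapes for an [m, n] output; names are illustrative.
void bias_shapes_example() {
  int64_t m = 16, n = 32, k = 64;
  auto a = torch::randint(-128, 128, {m, k}, torch::kInt8);  // activations
  auto b = torch::randint(-128, 128, {k, n}, torch::kInt8);  // weights
  auto a_scales = torch::rand({m});  // per-row float32, or {1} for per-tensor
  auto b_scales = torch::rand({n});  // per-column float32, or {1} for per-tensor
  auto bias = torch::rand({n});      // always one float per output column
  auto out = torch::empty({m, n}, torch::kBFloat16);
  // cutlass_scaled_mm(out, a, b, a_scales, b_scales, bias);
  // cutlass_scaled_mm(out, a, b, a_scales, b_scales, c10::nullopt);
}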