sm90 dispatch change #4

Closed
wants to merge 1 commit
65 changes: 36 additions & 29 deletions sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu
@@ -3,8 +3,7 @@
// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h

#pragma once

#include <cudaTypedefs.h>
#include <chrono>

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
@@ -33,9 +32,9 @@

#include "utils.hpp"


using namespace cute;

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CtaShape,
typename WarpShape, int Stages, bool WithBias,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd,
@@ -240,7 +239,8 @@ void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);

if (mp2 <= 2) {
if (mp2 <= 1) {
// m == 1
if (np2 <= 8192) {
return sm89_dispatch_bias<OutType, cutlass::gemm::GemmShape<16, 64, 128>, cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias);
} else if (np2 <= 16384) {
@@ -304,9 +304,7 @@
}
}
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
template <typename ElementType, typename OutElementType, typename AccumElementType, typename CTAShape,
typename ClusterShape, typename MainloopScheduleType, typename EpilogueScheduleType,
typename TileSchedulerType = void, bool WithBias = false>
@@ -403,7 +401,7 @@ struct DeviceGemmFp8RowwiseSm90
using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

using SlowAccum = DefaultSchedule;
using FastAccum = FastPongSchedule; // Default apply Pingpong
using FastAccum = FastDefaultSchedule;
using MainLoopSchedule = cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<ArchTag, OperatorClass, ElementA,
@@ -496,6 +494,7 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)

auto status = gemm_op.run(args, workspace.data_ptr(), stream);

TORCH_CHECK(status == cutlass::Status::kSuccess)
}

@@ -526,25 +525,40 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2

if (mp2 <= 64) {
// m in [1, 64]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias);
uint32_t const mp2 = next_pow_2(m);

if (mp2 <= 1) {
// m == 1
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 16) {
// m in [2, 16]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 64) {
// m in (16, 64]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _4, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 128) {
// m in (64, 128]
return sm90_dispatch_bias<OutType, Shape<_64, _128, _128>, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias);
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (128, 256]
return sm90_dispatch_bias<OutType, Shape<_64, _64, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 512) {
// m in (256, 512]
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 1024) {
// m in (512, 1024]
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_1, _1, _1>>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (128, inf)
// m in (1024, inf)
return sm90_dispatch_bias<OutType, Shape<_128, _128, _128>, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias);
}
}
#endif

torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias) {


TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor");
TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor");
TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor");
@@ -576,29 +590,22 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat
TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment");

auto sm_version = getSMVersion();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
if (sm_version >= 90) {
if (out_dtype == torch::kBFloat16) {
sm90_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm90_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
}
#endif

#if defined CUDA_VERSION && CUDA_VERSION >= 12040
if (sm_version == 89) {
} else if (sm_version == 89) {
if (out_dtype == torch::kBFloat16) {
sm89_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
} else {
sm89_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
}
return out;
}
#endif

} else {
TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version);
}

}

return out;
}
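
For context (not part of the diff above), here is a minimal, self-contained C++ sketch of the power-of-two bucketing that the new sm90_dispatch_shape uses. The next_pow_2 reimplementation and the sm90_bucket helper are illustrative assumptions made for this sketch; in the repository next_pow_2 comes from utils.hpp, and the real dispatch instantiates sm90_dispatch_bias with the CUTLASS tile/cluster shapes rather than returning strings.

// Illustrative sketch only: condenses the buckets of the new sm90_dispatch_shape.
// Assumes a GCC/Clang builtin for the bit scan; the real helper lives in utils.hpp.
#include <cstdint>
#include <cstdio>

static uint32_t next_pow_2(uint32_t v) {
  if (v <= 1) return v;                       // 0 and 1 map to themselves
  return 1u << (32 - __builtin_clz(v - 1));   // smallest power of two >= v
}

// Maps m to the tile/cluster pair chosen by the new dispatch; adjacent cases in
// the diff that resolve to the same pair are merged here for brevity.
static const char* sm90_bucket(uint32_t m) {
  uint32_t mp2 = next_pow_2(m);
  if (mp2 <= 64)   return "Shape<_64,_64,_128>,   Cluster<_1,_4,_1>";   // m in [1, 64]
  if (mp2 <= 256)  return "Shape<_64,_64,_128>,   Cluster<_1,_1,_1>";   // m in (64, 256]
  if (mp2 <= 1024) return "Shape<_128,_128,_128>, Cluster<_1,_1,_1>";   // m in (256, 1024]
  return "Shape<_128,_128,_128>, Cluster<_2,_1,_1>";                    // m in (1024, inf)
}

int main() {
  const uint32_t ms[] = {1, 16, 64, 128, 512, 4096};
  for (uint32_t m : ms)
    std::printf("m=%u -> %s\n", static_cast<unsigned>(m), sm90_bucket(m));
  return 0;
}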