diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index f1a2b73ff962b..8f2aa9425a029 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -234,15 +234,15 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, } template typename Epilogue, int32_t M> -struct sm90_fp8_config { + template typename Epilogue> +struct sm90_fp8_config_default { + // M in (128, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_128, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = cutlass_3x_gemm; @@ -250,14 +250,14 @@ struct sm90_fp8_config { template typename Epilogue> -struct sm90_fp8_config { +struct sm90_fp8_config_M128 { + // M in (64, 128] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; using TileShape = Shape<_64, _128, _128>; using ClusterShape = Shape<_2, _1, _1>; - using Cutlass3xGemm = cutlass_3x_gemm; @@ -265,7 +265,8 @@ struct sm90_fp8_config { template typename Epilogue> -struct sm90_fp8_config { +struct sm90_fp8_config_M64 { + // M in [1, 64] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; @@ -278,6 +279,78 @@ struct sm90_fp8_config { KernelSchedule, EpilogueSchedule>; }; +template typename Epilogue> +struct sm90_int8_config_default { + // For M > 128 and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M128 { + // For M in (64, 128] and any N + static_assert(std::is_same()); + using KernelSchedule = + typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M64 { + // For M in (32, 64] and any N + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NBig { + // For M in [1, 32] and N >= 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _128, _256>; + using ClusterShape = Shape<_1, _4, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + +template typename Epilogue> +struct sm90_int8_config_M32_NSmall { + // For M in [1, 32] and N < 8192 + static_assert(std::is_same()); + using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized; + using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; + using TileShape = Shape<_64, _64, _256>; + using ClusterShape = Shape<_1, _8, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm; +}; + } // namespace template ::Cutlass3xGemm; + typename sm90_fp8_config_default::Cutlass3xGemm; using Cutlass3xGemmM64 = - typename sm90_fp8_config::Cutlass3xGemm; + typename sm90_fp8_config_M64::Cutlass3xGemm; using Cutlass3xGemmM128 = - typename sm90_fp8_config::Cutlass3xGemm; + typename sm90_fp8_config_M128::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = @@ -316,6 +390,61 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a, } } +template typename Epilogue, + typename... EpilogueArgs> +void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kInt8); + TORCH_CHECK(b.dtype() == torch::kInt8); + + using Cutlass3xGemmDefault = + typename sm90_int8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM128 = + typename sm90_int8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM64 = + typename sm90_int8_config_M64::Cutlass3xGemm; + using Cutlass3xGemmM32NBig = + typename sm90_int8_config_M32_NBig::Cutlass3xGemm; + using Cutlass3xGemmM32NSmall = + typename sm90_int8_config_M32_NSmall::Cutlass3xGemm; + + uint32_t const n = out.size(1); + bool const is_small_n = n < 8192; + + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + if (mp2 <= 32) { + // m in [1, 32] + if (is_small_n) { + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } + } else if (mp2 <= 64) { + // m in (32, 64] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 128) { + // m in (64, 128] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else { + // m in (128, inf) + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } +} + void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -326,22 +455,14 @@ void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a, if (a.dtype() == torch::kInt8) { TORCH_CHECK(b.dtype() == torch::kInt8); - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _2, _1>; - using KernelSchedule = - typename cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized; - if (out.dtype() == torch::kBFloat16) { - return cutlass_gemm_caller>(out, a, b, a_scales, b_scales); + return cutlass_gemm_sm90_int8_dispatch( + out, a, b, a_scales, b_scales); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - - return cutlass_gemm_caller< - cutlass_3x_gemm>( + return cutlass_gemm_sm90_int8_dispatch( out, a, b, a_scales, b_scales); } } else {