feat: support sm90 cutlass group gemm #509

Merged
merged 8 commits on Oct 9, 2024
Changes from 6 commits
1 change: 0 additions & 1 deletion flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -13,7 +13,6 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-#pragma once
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
26 changes: 26 additions & 0 deletions flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu
@@ -0,0 +1,26 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>

torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer,
                                     torch::Tensor int_workspace_buffer,
                                     torch::Tensor seg_indptr, torch::Tensor weight_indices,
                                     torch::Tensor x, torch::Tensor weight,
                                     unsigned int batch_size, bool weight_column_major);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90,
        "Cutlass Segment GEMM operator for SM90");
}
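For orientation, a rough sketch of how the new binding could be driven once the `flashinfer._kernels_sm90` extension from setup.py is built. Shapes, dtypes, and workspace sizes below are illustrative assumptions, not values from this PR; the actual layout and workspace conventions live in csrc/group_gemm_sm90.cu, which is not part of this diff.

```python
# Hypothetical driver for the sm90 binding; shapes, dtypes, and workspace
# sizes are assumptions, not values taken from this PR.
import torch
from flashinfer import _kernels_sm90

batch_size = 2                 # number of segments / GEMM problems
d_in, d_out = 128, 256
# Segment i covers rows [seg_indptr[i], seg_indptr[i + 1]) of x and y.
seg_indptr = torch.tensor([0, 3, 7], dtype=torch.int64, device="cuda")
# Weight i for problem i; with weight_column_major=True each weight is (d_out, d_in).
weight_indices = torch.arange(batch_size, dtype=torch.int64, device="cuda")

x = torch.randn(7, d_in, dtype=torch.float16, device="cuda")
weight = torch.randn(batch_size, d_out, d_in, dtype=torch.float16, device="cuda")

# Scratch buffers handed through to CUTLASS (sizes here are a guess).
float_workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
int_workspace_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda")

y = _kernels_sm90.cutlass_segment_gemm_sm90(
    float_workspace_buffer, int_workspace_buffer, seg_indptr, weight_indices,
    x, weight, batch_size, True,  # weight_column_major
)
assert y.shape == (7, d_out)
```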
18 changes: 17 additions & 1 deletion flashinfer-aot/setup.py
@@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs) -> None:
include_dirs = [
    str(root.resolve() / "include"),
    str(root.resolve() / "3rdparty" / "cutlass" / "include"),  # for group gemm
+    str(root.resolve() / "3rdparty" / "cutlass" / "tools" / "util" / "include"),
]
extra_compile_args = {
    "cxx": [
@@ -371,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None:
"-use_fast_math",
],
}
extra_compile_args_sm90 = extra_compile_args.copy()
extra_compile_args_sm90["nvcc"].extend(
"-gencode arch=compute_90a,code=sm_90a".split()
)
ext_modules = []
ext_modules.append(
torch_cpp_ext.CUDAExtension(
@@ -385,12 +390,23 @@
"csrc/quantization.cu",
"csrc/group_gemm.cu",
"csrc/bmm_fp8.cu",
"csrc_aot/flashinfer_ops.cu",
"csrc_aot/flashinfer_ops.cu"
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._kernels_sm90",
sources=[
"csrc/group_gemm_sm90.cu",
"csrc_aot/flashinfer_ops_sm90.cu",
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args_sm90,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._decode_kernels",
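Because this second extension is compiled only for `compute_90a`/`sm_90a`, its kernels can launch only on Hopper-class GPUs. A minimal sketch of the kind of runtime gate callers would need; flashinfer's actual dispatch logic is not part of this diff, and the base module name `flashinfer._kernels` is an assumption here (only `_kernels_sm90` and `_decode_kernels` are visible in the hunk):

```python
# Sketch: route to the sm_90a-only extension on Hopper (SM 9.0+) devices only.
import torch

def has_sm90() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9

if has_sm90():
    from flashinfer import _kernels_sm90  # cutlass_segment_gemm_sm90 lives here
else:
    from flashinfer import _kernels       # pre-Hopper grouped GEMM fallback
```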
4 changes: 2 additions & 2 deletions include/flashinfer/gemm/group_gemm.cuh
@@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe

  // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API,
  // so I just use the kernel function directly, need to investigate more.
-  auto compute_args_kernel = compute_cutlass_group_gemm_args<DType>;
+  auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args<DType, DType>;
  compute_args_kernel<<<batch_size, 1, 0, stream>>>(
      problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w,
      (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major);
@@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe

} // namespace flashinfer

-#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
\ No newline at end of file
+#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_
57 changes: 45 additions & 12 deletions include/flashinfer/gemm/group_gemm_cutlass.cuh
@@ -16,11 +16,16 @@
#ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_
#define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_

+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/packed_stride.hpp"

namespace flashinfer {

@@ -41,21 +46,49 @@ struct cutlass_dtype<nv_bfloat16> {
  using type = cutlass::bfloat16_t;
};

-template <typename T>
-__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x,
-                                                T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w,
-                                                int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr,
-                                                int64_t* w_indices, size_t d_in, size_t d_out,
-                                                bool w_column_major) {
+template <>
+struct cutlass_dtype<__nv_fp8_e4m3> {
+  using type = cutlass::float_e4m3_t;
+};
+
+template <>
+struct cutlass_dtype<__nv_fp8_e5m2> {
+  using type = cutlass::float_e5m2_t;
+};
+
+template <typename DTypeIn, typename DTypeOut>
+__global__ void compute_sm80_cutlass_group_gemm_args(
+    cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
+    int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y,
+    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
  int i = blockIdx.x;
  int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
  all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
-  ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out;
-  ptr_x[i] = x + xy_indptr[i] * d_in;
-  ptr_y[i] = y + xy_indptr[i] * d_out;
-  ld_x[i] = k;                       // m * k
-  ld_w[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
-  ld_y[i] = n;                       // m * n
+  w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
+  x_ptr[i] = x + xy_indptr[i] * k;
+  y_ptr[i] = y + xy_indptr[i] * n;
+  x_ld[i] = k;                       // m * k
+  w_ld[i] = w_column_major ? k : n;  // k * n if column major, n * k if row major
+  y_ld[i] = n;                       // m * n
}

+template <typename DTypeIn, typename DTypeOut, typename ProblemShape, typename StrideA,
+          typename StrideB, typename StrideCD>
+__global__ void compute_sm90_cutlass_group_gemm_args(
+    ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr,
+    StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y,
+    int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) {
+  int i = blockIdx.x;
+  int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out;
+  all_problems[i] = ProblemShape(m, n, k);
+  w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n;
+  x_ptr[i] = x + xy_indptr[i] * k;
+  y_ptr[i] = y + xy_indptr[i] * n;
+
+  x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
+  w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1})
+                               : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
+  y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1});
+}

} // namespace group_gemm
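Both argument-preparation kernels encode the same segment layout: one GEMM problem per `blockIdx.x`, x and y packed row-contiguously, and segment boundaries given by `xy_indptr`. The real difference is the stride bookkeeping: a single leading dimension per operand for the SM80 grouped GEMM versus per-problem CuTe packed stride tuples built with `make_cute_packed_stride` for the SM90 path. A plain-Python rendering of the per-problem arithmetic, with toy sizes, for reference:

```python
# Per-problem offsets and leading dimensions, mirroring
# compute_sm80_cutlass_group_gemm_args. Toy values, illustration only.
xy_indptr = [0, 3, 7]   # segment i covers rows [xy_indptr[i], xy_indptr[i + 1])
d_in, d_out = 4, 2      # k and n
w_column_major = True

for i in range(len(xy_indptr) - 1):
    m, n, k = xy_indptr[i + 1] - xy_indptr[i], d_out, d_in
    x_off = xy_indptr[i] * k   # x_ptr[i] = x + xy_indptr[i] * k
    y_off = xy_indptr[i] * n   # y_ptr[i] = y + xy_indptr[i] * n
    w_off = i * k * n          # w_ptr[i] when w_indices is null
    x_ld = k                              # row-major (m, k)
    w_ld = k if w_column_major else n     # (n, k) column major reads with ld = k
    y_ld = n                              # row-major (m, n)
    print(f"problem {i}: m={m} n={n} k={k}; "
          f"offsets x={x_off} w={w_off} y={y_off}; lds=({x_ld}, {w_ld}, {y_ld})")
```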