feat: support sm90 cutlass group gemm #509

Merged · 8 commits · Oct 9, 2024
Changes from 1 commit
1 change: 0 additions & 1 deletion flashinfer-aot/csrc_aot/flashinfer_ops.cu
@@ -13,7 +13,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
129 changes: 0 additions & 129 deletions flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu
@@ -13,143 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <torch/extension.h>

void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value,
torch::Tensor append_indptr, std::optional<torch::Tensor> paged_kv_cache,
std::optional<torch::Tensor> paged_k_cache,
std::optional<torch::Tensor> paged_v_cache, torch::Tensor kv_indices,
torch::Tensor kv_indptr, torch::Tensor kv_last_page_len,
unsigned int layout);

std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b,
torch::Tensor s_b);

void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other,
torch::Tensor s_other, std::optional<torch::Tensor> mask = std::nullopt);

std::vector<torch::Tensor> merge_states(torch::Tensor v, torch::Tensor s);

torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
bool deterministic);

std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_p_arr,
double top_p_val, bool deterministic);

std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val, bool deterministic);

std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_min_p_arr,
double min_p_val, bool deterministic);

std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic);

torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_p_arr,
double top_p_val);

torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
unsigned int top_k_val);

std::vector<torch::Tensor> chain_speculative_sampling(
torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples,
torch::Tensor target_probs, std::optional<torch::Tensor> maybe_output_accepted_token_num,
std::optional<torch::Tensor> maybe_output_emitted_token_num, bool deterministic);

void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);

void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight,
double eps);

void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps);

void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight,
double eps);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length);

std::vector<torch::Tensor> apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta);

std::vector<torch::Tensor> apply_llama31_rope(torch::Tensor q, torch::Tensor k,
torch::Tensor indptr, torch::Tensor offsets,
bool interleave, float rope_scale, float rope_theta,
float low_freq_factor, float high_freq_factor,
float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
torch::Tensor output_indptr, const std::string& bitorder);

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
torch::Tensor& A_scale, torch::Tensor& B_scale);

torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr,
torch::Tensor weight_indices, torch::Tensor x,
torch::Tensor weight, unsigned int batch_size,
bool weight_column_major);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator");
m.def("merge_state", &merge_state, "Merge two self-attention states");
m.def("merge_state_in_place", &merge_state_in_place,
"Merge another self-attention state in-place.");
m.def("merge_states", &merge_states, "Merge multiple self-attention states");
m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities");
m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs,
"Top-k sampling from probabilities");
m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs,
"Min-p sampling from probabilities");
m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs,
"Top-p sampling from probabilities");
m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs,
"Top-k and top-p sampling from probabilities");
m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask");
m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask");
m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask");
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
"Gemma Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("apply_rope", &apply_rope, "Apply RoPE");
m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90");
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
}
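
For orientation, here is a minimal sketch of how the trimmed SM90 AOT module ends up being called from Python. The module name flashinfer._kernels_sm90 comes from the setup.py change below, the argument order follows the CutlassSegmentGEMMSM90 declaration above, and the tensors are placeholders rather than a working setup:

from flashinfer import _kernels_sm90  # prebuilt AOT extension for Hopper

# All arguments are torch.Tensors except batch_size (int) and the final bool,
# mirroring the CutlassSegmentGEMMSM90 signature above.
y = _kernels_sm90.cutlass_segment_gemm_sm90(
    float_workspace_buffer,  # scratch buffer for the CUTLASS workspace
    int_workspace_buffer,    # scratch buffer for problem metadata
    seg_indptr,              # row offsets of each segment in x
    weight_indices,          # which stacked weight each segment uses
    x,                       # stacked input rows
    weight,                  # stacked weight matrices
    batch_size,              # number of segments
    True,                    # weight_column_major
)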
30 changes: 17 additions & 13 deletions flashinfer-aot/setup.py
@@ -336,18 +336,6 @@ def remove_unwanted_pytorch_nvcc_flags():
except ValueError:
pass

def get_gemm_src_files():
cuda_major, _ = get_cuda_version()
if cuda_major < 9:
return [
"csrc/group_gemm.cu",
"csrc_aot/flashinfer_ops.cu",
]
else:
return [
"csrc/group_gemm_sm90.cu",
"csrc_aot/flashinfer_ops_sm90.cu",
]

class NinjaBuildExtension(torch_cpp_ext.BuildExtension):
def __init__(self, *args, **kwargs) -> None:
@@ -384,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None:
"-use_fast_math",
],
}
extra_compile_args_sm90 = extra_compile_args.copy()
extra_compile_args_sm90["nvcc"].extend(
"-gencode arch=compute_90a,code=sm_90a".split()
)
ext_modules = []
ext_modules.append(
torch_cpp_ext.CUDAExtension(
@@ -398,11 +390,23 @@ def __init__(self, *args, **kwargs) -> None:
"csrc/quantization.cu",
"csrc/group_gemm.cu",
"csrc/bmm_fp8.cu",
] + get_gemm_src_files(),
"csrc_aot/flashinfer_ops.cu"
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._kernels_sm90",
sources=[
"csrc/group_gemm_sm90.cu",
"csrc_aot/flashinfer_ops_sm90.cu",
],
include_dirs=include_dirs,
extra_compile_args=extra_compile_args_sm90,
)
)
ext_modules.append(
torch_cpp_ext.CUDAExtension(
name="flashinfer._decode_kernels",
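
One detail worth flagging in the flag handling above: dict.copy() is shallow, so extra_compile_args_sm90["nvcc"] is the same list object as extra_compile_args["nvcc"], and the extend() call adds the sm_90a gencode flags to the base extension's nvcc flags as well. A deep copy would keep the SM90-only flags scoped to the new extension; a possible sketch (not part of this diff, using the extra_compile_args dict defined earlier in setup.py):

import copy

extra_compile_args_sm90 = copy.deepcopy(extra_compile_args)
extra_compile_args_sm90["nvcc"].extend(
    ["-gencode", "arch=compute_90a,code=sm_90a"]
)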
4 changes: 0 additions & 4 deletions python/csrc/flashinfer_gemm_ops_sm90.cu
@@ -15,9 +15,6 @@
*/
#include <torch/extension.h>

void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D,
torch::Tensor& A_scale, torch::Tensor& B_scale);


// (... Tensor x_arr, Tensor w_arr, Tensor y_arr, Tensor x_stride, Tensor weight_stride, Tensor y_stride, Tensor problem_shape ...)
torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr,
@@ -27,5 +24,4 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90");
m.def("bmm_fp8", &bmm_fp8, "BMM FP8");
}
33 changes: 28 additions & 5 deletions python/flashinfer/gemm.py
@@ -18,12 +18,13 @@

import torch

from .utils import get_indptr
from .jit import get_gemm_src_files, load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops, is_sm90_capable
from .utils import get_indptr, get_compute_capability
from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops
from typing import Optional


_gemm_module = None
_gemm_module_sm90 = None


def get_gemm_module():
@@ -38,10 +39,31 @@ def get_gemm_module():
"gemm",
[
FLASHINFER_CSRC_DIR / "bmm_fp8.cu",
] + get_gemm_src_files(),
FLASHINFER_CSRC_DIR / "group_gemm.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
],
)
return _gemm_module

def get_gemm_sm90_module():
print("get_gemm_sm90_module")
global _gemm_module_sm90
if _gemm_module_sm90 is None:
if has_prebuilt_ops:
from . import _kernels_sm90

_gemm_module_sm90 = _kernels_sm90
else:
_gemm_module_sm90 = load_cuda_ops(
"gemm_sm90",
[
FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu",
],
extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"],
)
return _gemm_module_sm90


class SegmentGEMMWrapper:
r"""Wrapper for segment GEMM kernels.
@@ -198,8 +220,9 @@ def run(
if weight_indices is None:
# create an empty CPU tensor as placeholder
weight_indices = torch.empty(0, dtype=torch.int64)
if is_sm90_capable:
return get_gemm_module().cutlass_segment_gemm_sm90(
major, _ = get_compute_capability(x.device)
if major >= 9:
return get_gemm_sm90_module().cutlass_segment_gemm_sm90(
self._float_workspace_buffer,
self._int_workspace_buffer,
seg_indptr,
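
The upshot of this change is that the SM90 kernels are built and loaded only when a Hopper-class device is actually used: run() checks the device's compute capability and lazily loads either the gemm_sm90 module or the generic gemm module. A rough usage sketch; the wrapper's constructor argument and the run() keyword names are assumptions for illustration, and x, weights, seg_indptr are placeholders:

import torch
import flashinfer

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
segment_gemm = flashinfer.SegmentGEMMWrapper(workspace_buffer)

# On an SM90 GPU this dispatches to cutlass_segment_gemm_sm90;
# on pre-Hopper devices it falls back to the existing cutlass_segment_gemm path.
y = segment_gemm.run(
    x, weights, batch_size=4, weight_column_major=True, seg_indptr=seg_indptr
)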
45 changes: 13 additions & 32 deletions python/flashinfer/jit/__init__.py
@@ -85,16 +85,6 @@ def check_cuda_arch():
if arch < 75:
raise RuntimeError("FlashInfer requires sm75+")

def get_cuda_version() -> Tuple[int, int]:
if torch_cpp_ext.CUDA_HOME is None:
nvcc = "nvcc"
else:
nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc")
txt = subprocess.check_output([nvcc, "--version"], text=True)
major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0])
return major, minor

is_sm90_capable = get_cuda_version() >= (9, 0)

def clear_cache_dir():
if os.path.exists(FLASHINFER_JIT_DIR):
@@ -115,40 +105,31 @@ def remove_unwanted_pytorch_nvcc_flags():
except ValueError:
pass

def get_gemm_src_files():
if is_sm90_capable:
return [
FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu",
]
else:
return [
FLASHINFER_CSRC_DIR / "group_gemm.cu",
FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu",
]

remove_unwanted_pytorch_nvcc_flags()


def load_cuda_ops(
name: str,
sources: List[str],
extra_cflags: List[str] = ["-O3", "-Wno-switch-bool"],
extra_cuda_cflags: List[str] = [
extra_cflags: List[str] = [],
extra_cuda_cflags: List[str] = [],
extra_ldflags=None,
extra_include_paths=None,
verbose=False,
):
cflags = ["-O3", "-Wno-switch-bool"]
cuda_cflags = [
"-O3",
"-std=c++17",
"--threads",
"4",
# "-Xfatbin",
# "-compress-all",
"-use_fast_math",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8",
],
extra_ldflags=None,
extra_include_paths=None,
verbose=False,
):
]
cflags += extra_cflags
cuda_cflags += extra_cuda_cflags
logger.info(f"Loading JIT ops: {name}")
check_cuda_arch()
build_directory = FLASHINFER_JIT_DIR / name
@@ -162,8 +143,8 @@ def load_cuda_ops(
return torch_cpp_ext.load(
name,
list(map(lambda _: str(_), sources)),
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
extra_cflags=cflags,
extra_cuda_cflags=cuda_cflags,
extra_ldflags=extra_ldflags,
extra_include_paths=list(map(lambda _: str(_), extra_include_paths)),
build_directory=build_directory,
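
With this refactor the common flags (-O3, -std=c++17, the BF16/FP8 defines) are always applied, and extra_cflags / extra_cuda_cflags are appended on top instead of replacing the defaults. A caller that needs architecture-specific codegen now passes only the extra flags, e.g. (mirrors get_gemm_sm90_module in gemm.py above):

module = load_cuda_ops(
    "gemm_sm90",
    [
        FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu",
        FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu",
    ],
    extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"],
)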
6 changes: 6 additions & 0 deletions python/flashinfer/utils.py
@@ -176,3 +176,9 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
raise TypeError(
"dtype must be a string or torch.dtype, got {}".format(type(dtype))
)


def get_compute_capability(device: torch.device) -> Tuple[int, int]:
if device.type != "cuda":
raise ValueError("device must be a cuda device")
return torch.cuda.get_device_capability(device.index)
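
get_compute_capability keys the dispatch on the device's SM version (via torch.cuda.get_device_capability) rather than on the CUDA toolkit version that the removed get_cuda_version check used. A trivial usage sketch:

import torch
from flashinfer.utils import get_compute_capability

major, minor = get_compute_capability(torch.device("cuda:0"))
use_sm90_kernels = major >= 9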
1 change: 1 addition & 0 deletions python/setup.py
@@ -46,6 +46,7 @@ def clear_aot_config():

if __name__ == "__main__":
generate_build_meta()
clear_aot_config()
setuptools.setup(
name="flashinfer",
version=get_version(),