feat: specify gemm backend (#648)
Adds an optional `backend` argument to GEMM wrapper initialization.

Usage:
```python
# this will load the cutlass_segment_gemm_sm90 kernel
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="sm90")
```
Supported values: `sm90`, `sm80`, `auto` (default). With `auto`, the backend is resolved at run time from the device's compute capability via the new `determine_gemm_backend` helper.
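
For reference, a fuller usage sketch (a minimal example, not taken from this commit: tensor shapes are illustrative, and the `run(...)` argument names follow the parameters visible in this diff):

```python
import torch

import flashinfer

device = "cuda:0"
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device)
# backend="auto" (the default) resolves to "sm90" on compute capability >= 9
# and to "sm80" otherwise (see determine_gemm_backend in utils.py below).
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="auto")

batch_size, rows_per_segment, d_in, d_out = 4, 16, 128, 256
x = torch.randn(batch_size * rows_per_segment, d_in, dtype=torch.float16, device=device)
# column-major weights have shape (num_weights, d_out, d_in)
weights = torch.randn(batch_size, d_out, d_in, dtype=torch.float16, device=device)
# row offsets of each segment into x: [0, 16, 32, 48, 64]
seg_indptr = torch.arange(0, batch_size + 1, device=device) * rows_per_segment

y = segment_gemm.run(x, weights, batch_size, weight_column_major=True, seg_indptr=seg_indptr)
assert y.shape == (batch_size * rows_per_segment, d_out)
```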

---------

Co-authored-by: Zihao Ye <zihaoye.cs@gmail.com>
xslingcn and yzh119 authored Dec 5, 2024
1 parent 553ace5 commit 0cc1a51
Showing 4 changed files with 94 additions and 72 deletions.
2 changes: 0 additions & 2 deletions include/flashinfer/gemm/group_gemm_sm90.cuh
@@ -16,8 +16,6 @@
#ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_
#define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_

#include <sstream>

#include "../allocator.h"
#include "../cutlass_utils.cuh"
#include "../utils.cuh"
140 changes: 75 additions & 65 deletions python/flashinfer/gemm.py
@@ -24,7 +24,7 @@
from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
from .utils import (
_get_cache_buf,
get_compute_capability,
determine_gemm_backend,
get_cuda_stream,
get_indptr,
register_custom_op,
@@ -480,7 +480,9 @@ class SegmentGEMMWrapper:
True
"""

def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
def __init__(
self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
) -> None:
r"""Initialize the wrapper.
Parameters
@@ -493,6 +495,7 @@ def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
(1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device
)
self._float_workspace_buffer = float_workspace_buffer
self.backend = backend

def reset_workspace_buffer(
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
@@ -584,75 +587,82 @@ def run(
if weight_indices is None:
# create an empty CPU tensor as placeholder
weight_indices = torch.empty(0, dtype=torch.int64)
major, _ = get_compute_capability(x.device)
cumulative_batch_size = x.size(0)
d_out = weights.size(1) if weight_column_major else weights.size(2)
y = torch.zeros((cumulative_batch_size, d_out), dtype=x.dtype, device=x.device)
empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device)

if major >= 9:
(
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
) = launch_compute_sm90_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_sm90_module().cutlass_segment_gemm_sm90(
self._float_workspace_buffer,
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
)
if self.backend == "auto":
backend = determine_gemm_backend(x.device)
else:
(
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
) = launch_compute_sm80_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_module().cutlass_segment_gemm(
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
y,
empty_x_data,
weight_column_major,
)
backend = self.backend

match backend:
case "sm90":
(
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
) = launch_compute_sm90_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_sm90_module().cutlass_segment_gemm_sm90(
self._float_workspace_buffer,
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
)
case "sm80":
(
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
) = launch_compute_sm80_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_module().cutlass_segment_gemm(
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
y,
empty_x_data,
weight_column_major,
)
case _:
raise ValueError(f"Unsupported gemm backend: {backend}")
return y

forward = run
8 changes: 8 additions & 0 deletions python/flashinfer/utils.py
@@ -252,3 +252,11 @@ def register_fake_op(

def get_cuda_stream(device: torch.device) -> int:
return torch.cuda.current_stream(device).cuda_stream


def determine_gemm_backend(device: torch.device) -> str:
major, _ = get_compute_capability(device)
if major >= 9:
return "sm90"
else:
return "sm80"
16 changes: 11 additions & 5 deletions tests/test_group_gemm.py
@@ -18,6 +18,7 @@
import torch

import flashinfer
from flashinfer.utils import determine_gemm_backend

DTYPES = [torch.float16]
CUDA_DEVICES = ["cuda:0"]
@@ -31,6 +32,7 @@
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
def test_segment_gemm(
batch_size,
num_rows_per_batch,
@@ -40,12 +42,16 @@ def test_segment_gemm(
column_major,
dtype,
device,
backend,
):
if batch_size * num_rows_per_batch > 8192:
pytest.skip("batch_size * num_rows_per_batch too large for test.")
latest_supported_backend = determine_gemm_backend(torch.device(device))
if backend == "sm90" and latest_supported_backend == "sm80":
pytest.skip("sm90 backend not supported on this device.")
torch.manual_seed(42)
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend=backend)
x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device)
if use_weight_indices:
num_weights = 1024
@@ -99,7 +105,7 @@ def test_segment_gemm(


if __name__ == "__main__":
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto")
