feat: specify gemm backend #648

Merged · 4 commits · Dec 5, 2024
Changes from 1 commit
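This change adds an optional `backend` argument to `SegmentGEMMWrapper` (defaulting to `"auto"`), so callers can pin the CUTLASS kernel family instead of relying solely on compute-capability detection. A minimal usage sketch, assuming a CUDA device is available (workspace size and variable names are illustrative):

```python
import torch
import flashinfer

# Workspace buffer sized as in the test below; the exact size is illustrative.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")

# Default: "auto" resolves to "sm90" on compute capability >= 9, otherwise "sm80".
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="auto")

# The backend can also be pinned explicitly, e.g. on Hopper GPUs:
segment_gemm_sm90 = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend="sm90")
```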
140 changes: 75 additions & 65 deletions python/flashinfer/gemm.py
@@ -24,7 +24,7 @@
from .jit import FLASHINFER_CSRC_DIR, has_prebuilt_ops, load_cuda_ops
from .utils import (
_get_cache_buf,
get_compute_capability,
determine_gemm_backend,
get_cuda_stream,
get_indptr,
register_custom_op,
@@ -480,7 +480,9 @@ class SegmentGEMMWrapper:
True
"""

def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
def __init__(
self, float_workspace_buffer: torch.Tensor, backend: str = "auto"
) -> None:
r"""Initialize the wrapper.

Parameters
@@ -493,6 +495,7 @@ def __init__(self, float_workspace_buffer: torch.Tensor) -> None:
(1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device
)
self._float_workspace_buffer = float_workspace_buffer
self.backend = backend

def reset_workspace_buffer(
self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor
@@ -584,75 +587,82 @@ def run(
if weight_indices is None:
# create an empty CPU tensor as placeholder
weight_indices = torch.empty(0, dtype=torch.int64)
major, _ = get_compute_capability(x.device)
cumulative_batch_size = x.size(0)
d_out = weights.size(1) if weight_column_major else weights.size(2)
y = torch.zeros((cumulative_batch_size, d_out), dtype=x.dtype, device=x.device)
empty_x_data = torch.empty(0, dtype=x.dtype, device=x.device)

if major >= 9:
(
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
) = launch_compute_sm90_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_sm90_module().cutlass_segment_gemm_sm90(
self._float_workspace_buffer,
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
)
if self.backend == "auto":
backend = determine_gemm_backend(x.device)
else:
(
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
) = launch_compute_sm80_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_module().cutlass_segment_gemm(
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
y,
empty_x_data,
weight_column_major,
)
backend = self.backend

match backend:
case "sm90":
(
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
) = launch_compute_sm90_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_sm90_module().cutlass_segment_gemm_sm90(
self._float_workspace_buffer,
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_stride_data,
w_stride_data,
y_stride_data,
y, # for torch compile mutates_args
empty_x_data, # for kernel type dispatch
weight_column_major,
)
case "sm80":
(
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
) = launch_compute_sm80_group_gemm_args(
x,
weights,
y,
weight_column_major,
batch_size,
seg_indptr,
weight_indices,
)
get_gemm_module().cutlass_segment_gemm(
self._int_workspace_buffer,
all_problems,
x_data,
w_data,
y_data,
x_ld_data,
w_ld_data,
y_ld_data,
y,
empty_x_data,
weight_column_major,
)
case _:
raise ValueError(f"Unsupported gemm backend: {backend}")
return y

forward = run
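The rewritten `run` first resolves `"auto"` to a concrete backend via the new `determine_gemm_backend` helper, keeps an explicit `"sm90"`/`"sm80"` setting as-is, and rejects anything else with a `ValueError`. A condensed control-flow sketch (kernel argument packing and launches elided; note that `match`/`case` requires Python 3.10 or newer):

```python
import torch
from flashinfer.utils import determine_gemm_backend  # helper added in this PR


def resolve_and_dispatch(requested: str, device: torch.device) -> str:
    """Mirror of the dispatch logic in SegmentGEMMWrapper.run, kernels elided."""
    backend = determine_gemm_backend(device) if requested == "auto" else requested
    match backend:
        case "sm90":
            ...  # launch_compute_sm90_group_gemm_args + cutlass_segment_gemm_sm90
        case "sm80":
            ...  # launch_compute_sm80_group_gemm_args + cutlass_segment_gemm
        case _:
            raise ValueError(f"Unsupported gemm backend: {backend}")
    return backend
```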
8 changes: 8 additions & 0 deletions python/flashinfer/utils.py
@@ -252,3 +252,11 @@ def register_fake_op(

def get_cuda_stream(device: torch.device) -> int:
return torch.cuda.current_stream(device).cuda_stream


def determine_gemm_backend(device: torch.device) -> str:
major, _ = get_compute_capability(device)
if major >= 9:
return "sm90"
else:
return "sm80"
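The helper simply maps the device's compute capability major version to a backend string. For illustration (the device names are assumptions, not from the PR):

```python
import torch
from flashinfer.utils import determine_gemm_backend  # import path taken from the diff above

# Compute capability 9.x (e.g. H100)      -> "sm90"
# Compute capability 8.x (e.g. A100, L40) -> "sm80"
if torch.cuda.is_available():
    print(determine_gemm_backend(torch.device("cuda:0")))
```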
12 changes: 7 additions & 5 deletions tests/test_group_gemm.py
@@ -31,6 +31,7 @@
@pytest.mark.parametrize("column_major", [False, True])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("backend", ["auto", "sm90", "sm80"])
def test_segment_gemm(
batch_size,
num_rows_per_batch,
@@ -40,12 +41,13 @@ def test_segment_gemm(
column_major,
dtype,
device,
backend,
):
if batch_size * num_rows_per_batch > 8192:
pytest.skip("batch_size * num_rows_per_batch too large for test.")
torch.manual_seed(42)
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(device)
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer, backend)
x = torch.randn(batch_size * num_rows_per_batch, d_in, dtype=dtype).to(device)
if use_weight_indices:
num_weights = 1024
@@ -99,7 +101,7 @@ def test_segment_gemm(


if __name__ == "__main__":
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0")
test_segment_gemm(199, 17, 128, 1024, False, False, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, False, True, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, True, False, torch.float16, "cuda:0", "auto")
test_segment_gemm(199, 17, 128, 1024, True, True, torch.float16, "cuda:0", "auto")
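Because the test matrix now includes explicit `"sm90"` and `"sm80"` values, the `"sm90"` case can only pass on hardware with compute capability >= 9. A hypothetical guard, not part of this PR, that a test could use to skip unsupported combinations:

```python
import pytest
import torch
from flashinfer.utils import get_compute_capability  # existing helper referenced in the diff


def skip_if_backend_unsupported(backend: str, device: str) -> None:
    # Skip parametrized cases that the current GPU cannot run.
    major, _ = get_compute_capability(torch.device(device))
    if backend == "sm90" and major < 9:
        pytest.skip("sm90 backend requires compute capability >= 9.")
```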