Optimize the flashinfer indices update #1262

Merged · 5 commits · Sep 1, 2024
Changes from all commits
84 changes: 61 additions & 23 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -22,6 +22,8 @@

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -262,6 +264,42 @@ def init_flashinfer_handlers(
)


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    max_context_len,
    kv_indices_ptr,
):
    # One program per request: copy that request's token indices from its
    # row of req_to_token into the contiguous kv_indices output buffer.
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)
    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    # [kv_start, kv_end) is the range of tokens to copy for this request.
    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    # Advance the base pointers to this request's row and output slot.
    req_to_token_ptr += req_pool_index * max_context_len
    kv_indices_ptr += kv_indices_offset

    # Copy in chunks of BLOCK_SIZE; the mask guards the final partial chunk.
    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
    st_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = ld_offset < kv_end
        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
        ld_offset += BLOCK_SIZE
        st_offset += BLOCK_SIZE
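
For illustration, a minimal sketch of launching this kernel on a toy batch; the pool shape, row indices, and sequence lengths below are hypothetical and not part of this PR:

# Hypothetical example: two requests occupying rows 3 and 7 of the token pool.
req_to_token = torch.arange(16 * 64, dtype=torch.int32, device="cuda").reshape(16, 64)
req_pool_indices = torch.tensor([3, 7], dtype=torch.int32, device="cuda")
paged_kernel_lens = torch.tensor([5, 9], dtype=torch.int32, device="cuda")
kv_indptr = torch.zeros((3,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(14, dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(2,)](  # one Triton program per request
    req_to_token,
    req_pool_indices,
    paged_kernel_lens,
    kv_indptr,
    None,  # kv_start_idx is None: copy each row from its beginning
    req_to_token.size(1),
    kv_indices,
)
# kv_indices now equals torch.cat([req_to_token[3, :5], req_to_token[7, :9]])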


def update_flashinfer_indices(
forward_mode,
model_runner,
@@ -285,17 +323,18 @@ def update_flashinfer_indices(

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
]
for i in range(batch_size)
],
dim=0,
).contiguous()
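
# (This Triton launch replaces the per-request CPU gather above, filling
# kv_indices directly on the GPU instead of concatenating on the host.)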

kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
model_runner.req_to_token_pool.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
model_runner.req_to_token_pool.req_to_token.size(1),
kv_indices,
)

kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")

if forward_mode == ForwardMode.DECODE:
@@ -365,18 +404,17 @@

kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices = torch.cat(
[
model_runner.req_to_token_pool.req_to_token[
req_pool_indices_cpu[i],
kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
]
for i in range(batch_size)
],
dim=0,
).contiguous()
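
# (As in the decode path above, the Triton launch replaces the CPU gather;
# here kv_start_idx gives each request's starting offset within its
# req_to_token row, so the copied window is
# [kv_start_idx[i], kv_start_idx[i] + paged_kernel_lens[i]).)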

kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch_size,)](
model_runner.req_to_token_pool.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
kv_start_idx,
model_runner.req_to_token_pool.req_to_token.size(1),
kv_indices,
)

if forward_mode == ForwardMode.DECODE:
# CUDA graph uses a different flashinfer_decode_wrapper
76 changes: 76 additions & 0 deletions test/srt/test_create_kvindices.py
@@ -0,0 +1,76 @@
import unittest

import numpy as np
import torch

from sglang.srt.model_executor.forward_batch_info import (
create_flashinfer_kv_indices_triton,
)


class TestCreateKvIndices(unittest.TestCase):
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")

def _run_test(self, batch, max_batch, max_context_len):
req_to_token = torch.arange(
max_batch * max_context_len, dtype=torch.int32, device="cuda"
).reshape((max_batch, max_context_len))
        req_pool_indices = torch.from_numpy(
            np.random.choice(max_batch, size=batch, replace=False)
        ).to(dtype=torch.int32, device="cuda")
        paged_kernel_lens = torch.from_numpy(
            np.random.choice(max_context_len, size=batch, replace=False)
        ).to(dtype=torch.int32, device="cuda")

kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# ref
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
kv_indices_ref = torch.cat(
[
req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
for i in range(batch)
],
dim=0,
).contiguous()

# triton
kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(batch,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
req_to_token.size(1),
kv_indices_triton,
)

# Check
self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))

def test_create_kvindices(self):
BATCH = [1, 37, 1786]
MAX_BATCH = 4096
MAX_CONTEXT_LEN = 4096
for batch in BATCH:
self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)


if __name__ == "__main__":
unittest.main()
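
To run this test directly (assuming a CUDA-capable GPU and that sglang is importable, e.g. from the repository root):

python test/srt/test_create_kvindices.py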