From 5fecd342b00b3d9f2ec58aaf7949d3b8a11b66d3 Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Thu, 29 Aug 2024 23:13:29 +0000
Subject: [PATCH 1/3] Optimize creation of flashinfer kv_indices to reduce CPU
 time at high concurrency

---
 .../srt/model_executor/forward_batch_info.py | 111 ++++++++++++++----
 1 file changed, 88 insertions(+), 23 deletions(-)

diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 3d40c9d755..88c7bb39e2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -22,6 +22,8 @@
 
 import numpy as np
 import torch
+import triton
+import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -262,6 +264,37 @@ def init_flashinfer_handlers(
     )
 
 
+@triton.jit
+def create_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    max_context_len,
+    kv_indices_ptr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid)
+
+    req_to_token_ptr += req_pool_index * max_context_len
+    kv_indices_ptr += kv_indices_offset
+    for start_off in range(kv_start, kv_end, BLOCK_SIZE):
+        offset = start_off + tl.arange(0, BLOCK_SIZE)
+        mask = offset < kv_end
+        data = tl.load(req_to_token_ptr + offset, mask=mask)
+        tl.store(kv_indices_ptr + offset, data, mask=mask)
+
+
 def update_flashinfer_indices(
     forward_mode,
     model_runner,
@@ -285,17 +318,32 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-        kv_indices = torch.cat(
-            [
-                model_runner.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                ]
-                for i in range(batch_size)
-            ],
-            dim=0,
-        ).contiguous()
+
+        if True:
+            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+            create_kv_indices_triton[(batch_size,)](
+                model_runner.req_to_token_pool.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                model_runner.req_to_token_pool.req_to_token.size(1),
+                kv_indices,
+                BLOCK_SIZE=512,
+            )
+        else:
+            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+            kv_indices = torch.cat(
+                [
+                    model_runner.req_to_token_pool.req_to_token[
+                        req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
+                    ]
+                    for i in range(batch_size)
+                ],
+                dim=0,
+            ).contiguous()
+
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
@@ -365,18 +413,35 @@ def update_flashinfer_indices(
 
             kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
             kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-            kv_indices = torch.cat(
-                [
-                    model_runner.req_to_token_pool.req_to_token[
-                        req_pool_indices_cpu[i],
-                        kv_start_idx[i] : kv_start_idx[i] + paged_kernel_lens_cpu[i],
-                    ]
-                    for i in range(batch_size)
-                ],
-                dim=0,
-            ).contiguous()
+
+            if True:
+                kv_indices = torch.empty(
+                    kv_indptr[-1], dtype=torch.int32, device="cuda"
+                )
+                create_kv_indices_triton[(batch_size,)](
+                    model_runner.req_to_token_pool.req_to_token,
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    kv_indptr,
+                    kv_start_idx,
+                    model_runner.req_to_token_pool.req_to_token.size(1),
+                    kv_indices,
+                    BLOCK_SIZE=512,
+                )
+            else:
+                req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+                paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+                kv_indices = torch.cat(
+                    [
+                        model_runner.req_to_token_pool.req_to_token[
+                            req_pool_indices_cpu[i],
+                            kv_start_idx[i] : kv_start_idx[i]
+                            + paged_kernel_lens_cpu[i],
+                        ]
+                        for i in range(batch_size)
+                    ],
+                    dim=0,
+                ).contiguous()
 
             if forward_mode == ForwardMode.DECODE:
                 # CUDA graph uses different flashinfer_decode_wrapper
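What the new kernel computes, in plain NumPy terms: each program instance handles one request and copies a contiguous segment of its req_to_token row into the flat kv_indices array, packed back-to-back according to kv_indptr. The sketch below is an illustrative CPU reference only (not part of the patches); the helper name create_kv_indices_reference is made up here, the argument names follow the diff above, and the start offset defaults to zero when kv_start_idx is omitted, matching the call site that passes None.

    import numpy as np

    def create_kv_indices_reference(
        req_to_token,        # [max_batch, max_context_len] KV-cache token locations per request slot
        req_pool_indices,    # [batch] row of req_to_token used by each request
        paged_kernel_lens,   # [batch] number of KV entries per request
        kv_indptr,           # [batch + 1] exclusive prefix sum of paged_kernel_lens
        kv_start_idx=None,   # [batch] optional per-request start offset (sliding window)
    ):
        kv_indices = np.empty(int(kv_indptr[-1]), dtype=np.int32)
        for i, req in enumerate(req_pool_indices):
            start = int(kv_start_idx[i]) if kv_start_idx is not None else 0
            length = int(paged_kernel_lens[i])
            kv_indices[int(kv_indptr[i]) : int(kv_indptr[i + 1])] = req_to_token[
                req, start : start + length
            ]
        return kv_indices

Doing this gather on the GPU removes the per-request .cpu().numpy() transfers and the Python-level torch.cat over batch_size slices, which is where the CPU time goes at high concurrency.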
From f798e25a70a91479b256cedb6011cd02374a078c Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Fri, 30 Aug 2024 07:04:45 +0000
Subject: [PATCH 2/3] Fix a logic bug and avoid a Triton 3.0.0 error

---
 .../srt/model_executor/forward_batch_info.py | 95 +++++++------------
 1 file changed, 34 insertions(+), 61 deletions(-)

diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 88c7bb39e2..a443b113d4 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -265,7 +265,7 @@ def init_flashinfer_handlers(
 
 
 @triton.jit
-def create_kv_indices_triton(
+def create_flashinfer_kv_indices_triton(
     req_to_token_ptr,  # [max_batch, max_context_len]
     req_pool_indices_ptr,
     page_kernel_lens_ptr,
@@ -273,8 +273,8 @@ def create_flashinfer_kv_indices_triton(
     kv_start_idx,
     max_context_len,
     kv_indices_ptr,
-    BLOCK_SIZE: tl.constexpr,
 ):
+    BLOCK_SIZE: tl.constexpr = 512
     pid = tl.program_id(axis=0)
     req_pool_index = tl.load(req_pool_indices_ptr + pid)
     kv_indices_offset = tl.load(kv_indptr + pid)
@@ -282,17 +282,22 @@ def create_flashinfer_kv_indices_triton(
     kv_start = 0
     kv_end = 0
     if kv_start_idx:
-        kv_start = tl.load(kv_start_idx + pid)
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
         kv_end = kv_start
-    kv_end += tl.load(page_kernel_lens_ptr + pid)
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
 
     req_to_token_ptr += req_pool_index * max_context_len
     kv_indices_ptr += kv_indices_offset
-    for start_off in range(kv_start, kv_end, BLOCK_SIZE):
-        offset = start_off + tl.arange(0, BLOCK_SIZE)
-        mask = offset < kv_end
-        data = tl.load(req_to_token_ptr + offset, mask=mask)
-        tl.store(kv_indices_ptr + offset, data, mask=mask)
+
+    ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
+    st_offset = tl.arange(0, BLOCK_SIZE)
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for _ in range(num_loop):
+        mask = ld_offset < kv_end
+        data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
+        tl.store(kv_indices_ptr + st_offset, data, mask=mask)
+        ld_offset += BLOCK_SIZE
+        st_offset += BLOCK_SIZE
 
 
 def update_flashinfer_indices(
@@ -319,30 +324,16 @@ def update_flashinfer_indices(
 
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-        if True:
-            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
-            create_kv_indices_triton[(batch_size,)](
-                model_runner.req_to_token_pool.req_to_token,
-                req_pool_indices,
-                paged_kernel_lens,
-                kv_indptr,
-                None,
-                model_runner.req_to_token_pool.req_to_token.size(1),
-                kv_indices,
-                BLOCK_SIZE=512,
-            )
-        else:
-            req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-            paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-            kv_indices = torch.cat(
-                [
-                    model_runner.req_to_token_pool.req_to_token[
-                        req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
-                    ]
-                    for i in range(batch_size)
-                ],
-                dim=0,
-            ).contiguous()
+        kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch_size,)](
+            model_runner.req_to_token_pool.req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            model_runner.req_to_token_pool.req_to_token.size(1),
+            kv_indices,
+        )
 
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
@@ -414,34 +405,16 @@ def update_flashinfer_indices(
             kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
             kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
-            if True:
-                kv_indices = torch.empty(
-                    kv_indptr[-1], dtype=torch.int32, device="cuda"
-                )
-                create_kv_indices_triton[(batch_size,)](
-                    model_runner.req_to_token_pool.req_to_token,
-                    req_pool_indices,
-                    paged_kernel_lens,
-                    kv_indptr,
-                    kv_start_idx,
-                    model_runner.req_to_token_pool.req_to_token.size(1),
-                    kv_indices,
-                    BLOCK_SIZE=512,
-                )
-            else:
-                req_pool_indices_cpu = req_pool_indices.cpu().numpy()
-                paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
-                kv_indices = torch.cat(
-                    [
-                        model_runner.req_to_token_pool.req_to_token[
-                            req_pool_indices_cpu[i],
-                            kv_start_idx[i] : kv_start_idx[i]
-                            + paged_kernel_lens_cpu[i],
-                        ]
-                        for i in range(batch_size)
-                    ],
-                    dim=0,
-                ).contiguous()
+            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+            create_flashinfer_kv_indices_triton[(batch_size,)](
+                model_runner.req_to_token_pool.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                kv_start_idx,
+                model_runner.req_to_token_pool.req_to_token.size(1),
+                kv_indices,
+            )
 
             if forward_mode == ForwardMode.DECODE:
                 # CUDA graph uses different flashinfer_decode_wrapper
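The logic bug fixed above is in the store offset: the first version advanced a single offset starting at kv_start and used it for both the load and the store, so whenever kv_start_idx is non-zero (the sliding-window branch) the writes landed kv_start elements past the beginning of the request's segment in kv_indices. After the fix, the load offset walks req_to_token[row, kv_start:kv_end] while the store offset starts at zero inside the segment. A plain-Python rendering of one program instance after the fix (illustrative only; copy_one_request is a made-up name and the block size mirrors the kernel's constexpr):

    def copy_one_request(req_to_token_row, kv_indices, kv_indices_offset,
                         kv_start, kv_end, block_size=512):
        # Equivalent of one Triton program instance after the fix: loads walk
        # req_to_token_row[kv_start:kv_end], stores begin at position 0 of this
        # request's segment in kv_indices.
        ld, st = kv_start, 0
        while ld < kv_end:
            n = min(block_size, kv_end - ld)
            kv_indices[kv_indices_offset + st : kv_indices_offset + st + n] = (
                req_to_token_row[ld : ld + n]
            )
            ld += block_size
            st += block_size

The kernel expresses the same thing with masked block loads and stores and a trip count from tl.cdiv(kv_end - kv_start, BLOCK_SIZE); per the commit message, this restructuring also avoids an error seen under Triton 3.0.0.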
From 83ca9123b297505314a6392e8ba2c5c2e9b40d50 Mon Sep 17 00:00:00 2001
From: xiaobo
Date: Fri, 30 Aug 2024 18:44:54 +0000
Subject: [PATCH 3/3] Add a unit test for the kv_indices Triton kernel

---
 test/srt/test_create_kvindices.py | 76 +++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)
 create mode 100644 test/srt/test_create_kvindices.py

diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py
new file mode 100644
index 0000000000..230302f264
--- /dev/null
+++ b/test/srt/test_create_kvindices.py
@@ -0,0 +1,76 @@
+import itertools
+import unittest
+
+import numpy as np
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import (
+    create_flashinfer_kv_indices_triton,
+)
+
+
+class TestCreateKvIndices(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_test(self, batch, max_batch, max_context_len):
+        req_to_token = torch.arange(
+            max_batch * max_context_len, dtype=torch.int32, device="cuda"
+        ).reshape((max_batch, max_context_len))
+        req_pool_indices = torch.tensor(
+            torch.from_numpy(
+                np.random.choice(range(max_batch), size=batch, replace=False)
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+        paged_kernel_lens = torch.tensor(
+            torch.from_numpy(
+                np.random.choice(range(max_context_len), size=batch, replace=False)
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
+        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        # ref
+        req_pool_indices_cpu = req_pool_indices.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
+        kv_indices_ref = torch.cat(
+            [
+                req_to_token[req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]]
+                for i in range(batch)
+            ],
+            dim=0,
+        ).contiguous()
+
+        # triton
+        kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        create_flashinfer_kv_indices_triton[(batch,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            kv_indptr,
+            None,
+            req_to_token.size(1),
+            kv_indices_triton,
+        )
+
+        # Check
+        self.assertTrue(torch.equal(kv_indices_ref, kv_indices_triton))
+
+    def test_create_kvindices(self):
+        BATCH = [1, 37, 1786]
+        MAX_BATCH = 4096
+        MAX_CONTEXT_LEN = 4096
+        for batch in BATCH:
+            self._run_test(batch, MAX_BATCH, MAX_CONTEXT_LEN)
+
+
+if __name__ == "__main__":
+    unittest.main()
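The unit test can be run directly on a CUDA machine with `python3 test/srt/test_create_kvindices.py`; it skips itself when CUDA is unavailable. To see the CPU-time effect that motivates the series, a rough micro-benchmark along the same lines as the test is sketched below. It is illustrative only (not part of the patches), assumes the patched sglang tree is importable, and the absolute numbers will vary with GPU and batch size:

    import time

    import numpy as np
    import torch

    from sglang.srt.model_executor.forward_batch_info import (
        create_flashinfer_kv_indices_triton,
    )


    def bench(batch, max_batch=4096, max_context_len=4096, iters=20):
        req_to_token = torch.arange(
            max_batch * max_context_len, dtype=torch.int32, device="cuda"
        ).reshape(max_batch, max_context_len)
        req_pool_indices = torch.from_numpy(
            np.random.choice(max_batch, size=batch, replace=False).astype(np.int32)
        ).cuda()
        paged_kernel_lens = torch.from_numpy(
            np.random.choice(max_context_len, size=batch, replace=False).astype(np.int32)
        ).cuda()
        kv_indptr = torch.zeros(batch + 1, dtype=torch.int32, device="cuda")
        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        def old_path():
            # Original implementation: CPU round-trips plus a Python-level torch.cat.
            idx = req_pool_indices.cpu().numpy()
            lens = paged_kernel_lens.cpu().numpy()
            return torch.cat(
                [req_to_token[idx[i], : lens[i]] for i in range(batch)], dim=0
            ).contiguous()

        def new_path():
            # New implementation: one kernel launch, one program per request.
            kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
            create_flashinfer_kv_indices_triton[(batch,)](
                req_to_token,
                req_pool_indices,
                paged_kernel_lens,
                kv_indptr,
                None,
                req_to_token.size(1),
                kv_indices,
            )
            return kv_indices

        for name, fn in (("torch.cat", old_path), ("triton", new_path)):
            fn()  # warmup; also triggers JIT compilation of the kernel
            torch.cuda.synchronize()
            tic = time.perf_counter()
            for _ in range(iters):
                fn()
            torch.cuda.synchronize()
            ms = (time.perf_counter() - tic) / iters * 1e3
            print(f"batch={batch:5d} {name:9s} {ms:.3f} ms")


    if __name__ == "__main__":
        for batch in (1, 37, 1786):
            bench(batch)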