diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 85ca560a92..f99f3377e5 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -221,7 +221,7 @@ def __init__(
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = 0

         # Logprobs (arguments)
         self.return_logprob = False
@@ -888,15 +888,14 @@ def prepare_for_decode(self, enable_overlap: bool = False):

     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
             ]

         if keep_indices is None or len(keep_indices) == 0:
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 6ea6ff194d..4388928958 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -138,7 +138,7 @@ def __init__(
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
@@ -178,7 +178,7 @@ def _prefill_one_req(
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -194,8 +194,13 @@ def add_inflight_req(self, req: Req):
             ),
         )

-        # Return if chunked prefill not finished
-        return req if truncated else None
+        if truncated:
+            # Continue to chunk the request
+            assert req.is_being_chunked
+            self.new_chunked_req = req
+        else:
+            # Release the being chunked status
+            req.is_being_chunked -= 1

     @contextmanager
     def _lock_node(self, last_node: TreeNode):
@@ -264,11 +269,14 @@ def add_req_state(r, insert_sort=False):
                 )
             else:
                 # Chunked prefill
+                assert self.new_chunked_req is None
+
                 trunc_len = self.rem_chunk_tokens
                 req.extend_input_len = trunc_len
+                req.is_being_chunked += 1
                 req.fill_ids = req.fill_ids[:trunc_len]
                 self.can_run_list.append(req)
-                self.new_inflight_req = req
+                self.new_chunked_req = req
                 self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -310,15 +318,18 @@ def add_one_req(self, req: Req):
                 ),
             )
         else:
-            # Chunked prefill
             trunc_len = self.rem_chunk_tokens
             if trunc_len == 0:
                 return AddReqResult.OTHER

+            # Chunked prefill
+            assert self.new_chunked_req is None
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+            req.is_being_chunked += 1
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_chunked_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
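The schedule_policy.py change replaces the old implicit protocol, where `add_inflight_req` returned the request only while its prefill was still truncated, with explicit state: a per-request `is_being_chunked` counter plus a `new_chunked_req` slot on the adder. Below is a minimal, self-contained sketch of that protocol; `FakeReq` and `FakeAdder` are hypothetical stand-ins for the real `Req` and `PrefillAdder`, and only the bookkeeping mirrors the diff.

```python
class FakeReq:
    def __init__(self, num_tokens):
        self.remaining = num_tokens   # prefill tokens still to run
        self.is_being_chunked = 0     # > 0 while chunked prefill is in flight


class FakeAdder:
    def __init__(self, chunk_size):
        self.chunk_size = chunk_size
        self.new_chunked_req = None   # handed to the next scheduling round

    def add_one_req(self, req):
        # First time a long request is seen: take one chunk and mark it.
        if req.remaining > self.chunk_size:
            req.remaining -= self.chunk_size
            req.is_being_chunked += 1
            self.new_chunked_req = req

    def add_being_chunked_req(self, req):
        # Later rounds: keep chunking, or release the being-chunked status.
        truncated = req.remaining > self.chunk_size
        req.remaining -= min(req.remaining, self.chunk_size)
        if truncated:
            assert req.is_being_chunked
            self.new_chunked_req = req
        else:
            req.is_being_chunked -= 1


req = FakeReq(num_tokens=300)
being_chunked = None
for _ in range(3):  # three scheduling rounds at chunk size 128
    adder = FakeAdder(chunk_size=128)
    if being_chunked is not None:
        adder.add_being_chunked_req(being_chunked)
    else:
        adder.add_one_req(req)
    being_chunked = adder.new_chunked_req

assert req.remaining == 0 and req.is_being_chunked == 0
```

Keeping the marker on the request itself lets `process_batch_result_prefill` decide per request whether an output token should be emitted, instead of threading a return value through every scheduling round.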
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 7c7780a64d..6f6264b2ad 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -231,7 +231,7 @@ def __init__(
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
         self.is_mixed_chunk = (
             self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
         )
@@ -544,20 +544,18 @@ def check_memory(self):
             )
             exit(1) if crash_on_warning else None

-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
-                self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
-                )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+            if self.being_chunked_req:
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
+                # The being-chunked request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False

             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -588,7 +586,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None

         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -611,15 +609,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
             num_mixed_running,
         )

-        has_inflight = self.current_inflight_req is not None
-        if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
-            )
-
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -627,6 +616,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                 else set([])
             )

+        # NOTE: if there is a request being chunked, we always add it first
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
+            # NOTE: the prefix_indices of a being-chunked prefill should align with the last prefill result
+            self.being_chunked_req.init_next_round_input()
+            adder.add_being_chunked_req(self.being_chunked_req)
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -660,12 +656,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
             x for x in self.waiting_queue if x not in set(can_run_list)
         ]

-        if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
-
-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        # Update the being-chunked request for the next round
+        self.being_chunked_req = adder.new_chunked_req

         # Print stats
         if self.tp_rank == 0:
@@ -694,7 +686,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )
         else:
             logger.info(
@@ -705,7 +697,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
             )

         # Create a new batch
@@ -833,10 +825,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         # Check finish conditions
         logprob_pt = 0
         for i, req in enumerate(batch.reqs):
-            if req.is_inflight_req > 0:
-                req.is_inflight_req -= 1
-            else:
-                # Inflight reqs' prefill is not finished
+            if not req.is_being_chunked:
+                # Being-chunked reqs' prefill is not finished, so skip them here
                 req.completion_tokens_wo_jump_forward += 1
                 req.output_ids.append(next_token_ids[i])
                 req.check_finished()
@@ -860,10 +850,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
-                else:
-                    # Inflight reqs' prefill is not finished
+                if not req.is_being_chunked:
+                    # Being-chunked reqs' prefill is not finished, so skip them here
                     # dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
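In scheduler.py, a being-chunked request is filtered out of the last prefill batch before merging, its computed KV is retained through `cache_unfinished_req`, and its `req_pool_idx` is freed; on the next round it is re-added first, and `init_next_round_input` recovers the already-finished chunks as `prefix_indices`. A toy model of why freeing the pool slot is safe follows; `TinyTreeCache` is an invented stand-in, not the real radix `TreeCache` API.

```python
class TinyTreeCache:
    """Invented stand-in: maps a token prefix to its cached KV indices."""

    def __init__(self):
        self.store = {}

    def cache_unfinished_req(self, tokens, kv_indices):
        # Called when a chunk finishes but the request has not.
        self.store[tuple(tokens)] = kv_indices

    def match_prefix(self, tokens):
        # Return the longest cached prefix of `tokens`.
        best = []
        for prefix, kv in self.store.items():
            if tuple(tokens[: len(prefix)]) == prefix and len(kv) > len(best):
                best = kv
        return best


cache = TinyTreeCache()
# Round 1 prefilled tokens [5, 6, 7] into KV slots [0, 1, 2]:
cache.cache_unfinished_req([5, 6, 7], [0, 1, 2])
# Round 2: the request's fill_ids are matched against the tree, so only
# the remaining suffix [8, 9] needs a fresh prefill chunk.
assert cache.match_prefix([5, 6, 7, 8, 9]) == [0, 1, 2]
```

The request keeps its `rid` across rounds, so only the token-pool slot is recycled; the cached prefix ensures earlier chunks are not recomputed.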
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 1237df7095..f7277f03da 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -19,6 +19,7 @@
     "test_openai_server.py",
     "test_overlap_schedule.py",
     "test_pytorch_sampling_backend.py",
+    "test_radix_attention.py",
     "test_retract_decode.py",
     "test_server_args.py",
     "test_skip_tokenizer_init.py",
diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py
new file mode 100644
index 0000000000..292a7b454d
--- /dev/null
+++ b/test/srt/test_radix_attention.py
@@ -0,0 +1,112 @@
+import os
+import random
+import unittest
+
+import requests
+
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    kill_child_process,
+    popen_launch_server,
+)
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_test(base_url, nodes):
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
+
+
+class TestRadixCacheFCFS(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chunked-prefill-size",
+                "128",
+                "--max-total-tokens",
+                "20000",
"--schedule-policy", + "fcfs", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_child_process(cls.process.pid) + + def test_radix_attention(self): + nodes = gen_radix_tree() + run_test(self.base_url, nodes) + + +class TestRadixCacheLPM(TestRadixCacheFCFS): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--chunked-prefill-size", + "128", + "--max-total-tokens", + "20000", + "--schedule-policy", + "lpm", + ], + ) + + +if __name__ == "__main__": + os.environ["SGLANG_TEST_RETRACT"] = "true" + unittest.main()