diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py
index 5e7290edc5..e845266984 100644
--- a/python/sglang/global_config.py
+++ b/python/sglang/global_config.py
@@ -15,7 +15,7 @@ def __init__(self):
 
         # Runtime constants: New generation token ratio estimation
         self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
+        self.min_new_token_ratio = 0.1
         self.new_token_ratio_decay = 0.001
 
         # Runtime constants: others
@@ -32,5 +32,15 @@ def __init__(self):
         self.enable_precache_with_tracing = True
         self.enable_parallel_encoding = True
 
+    def adjust_new_token_ratio(self, schedule_conservativeness=1):
+        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
+        min_new_token_ratio = min(
+            self.min_new_token_ratio * schedule_conservativeness,
+            1.0,
+        )
+        init_new_token_ratio = max(self.init_new_token_ratio, min_new_token_ratio)
+
+        return min_new_token_ratio, init_new_token_ratio
+
 
 global_config = GlobalConfig()
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index fcd06d8cc9..39fc1e558f 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -222,7 +222,7 @@ def __init__(
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
-        self.is_inflight_req = 0
+        self.is_being_chunked = False
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -906,15 +906,14 @@ def prepare_for_decode(self, enable_overlap: bool = False):
 
     def filter_batch(
         self,
-        current_inflight_req: Optional[Req] = None,
+        being_chunked_req: Optional[Req] = None,
         keep_indices: Optional[List[int]] = None,
     ):
         if keep_indices is None:
             keep_indices = [
                 i
                 for i in range(len(self.reqs))
-                if not self.reqs[i].finished()
-                and self.reqs[i] is not current_inflight_req
+                if not self.reqs[i].finished() and self.reqs[i] is not being_chunked_req
             ]
 
         if keep_indices is None or len(keep_indices) == 0:
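Note on the new helper above: `GlobalConfig.adjust_new_token_ratio` now owns the clamping that `Scheduler.__init__` used to do inline, and the scheduler (later in this patch) starts from `init_new_token_ratio` and decays toward the returned minimum. A minimal sketch of that contract using the default constants from this file; the loop is illustrative only, not the scheduler's actual control flow:

```python
from sglang.global_config import global_config

# schedule_conservativeness normally comes from server_args.
min_ratio, init_ratio = global_config.adjust_new_token_ratio(schedule_conservativeness=1)

new_token_ratio = init_ratio  # 0.7 by default
for _ in range(1000):
    # Mirrors update_running_batch(): decay the estimate when nothing was retracted.
    new_token_ratio = max(
        new_token_ratio - global_config.new_token_ratio_decay,  # 0.001 per step
        min_ratio,  # 0.1 * schedule_conservativeness, capped at 1.0
    )
assert new_token_ratio == min_ratio  # fully decayed after enough steps
```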
diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py
index 45c9be37a6..a5362ff7ca 100644
--- a/python/sglang/srt/managers/schedule_policy.py
+++ b/python/sglang/srt/managers/schedule_policy.py
@@ -136,7 +136,7 @@ def __init__(
 
         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
@@ -176,7 +176,7 @@ def _prefill_one_req(
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -192,8 +192,13 @@ def add_inflight_req(self, req: Req):
             ),
         )
 
-        # Return if chunked prefill not finished
-        return req if truncated else None
+        if truncated:
+            # Continue to chunk the request
+            assert req.is_being_chunked
+            self.new_chunked_req = req
+        else:
+            # Release the being chunked status
+            req.is_being_chunked = False
 
     @contextmanager
     def _lock_node(self, last_node: TreeNode):
@@ -262,11 +267,14 @@ def add_req_state(r, insert_sort=False):
             )
         else:
             # Chunked prefill
+            assert self.new_chunked_req is None
+
             trunc_len = self.rem_chunk_tokens
             req.extend_input_len = trunc_len
+            req.is_being_chunked = True
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_chunked_req = req
             self._prefill_one_req(0, trunc_len, 0)
 
         return self.budget_state()
@@ -305,15 +313,18 @@ def add_one_req(self, req: Req):
                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
-            # Chunked prefill
             trunc_len = self.rem_chunk_tokens
             if trunc_len == 0:
                 return AddReqResult.OTHER
 
+            # Chunked prefill
+            assert self.new_chunked_req is None
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+            req.is_being_chunked = True
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_chunked_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
 
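The `PrefillAdder` changes above replace the old "return the truncated request" protocol of `add_inflight_req` with an explicit `new_chunked_req` slot plus a per-request `is_being_chunked` flag that is only cleared on the final chunk. A self-contained toy sketch of the lifecycle these fields coordinate; `ToyReq` and `prefill_in_chunks` are made up for illustration and are not part of the real `PrefillAdder`/`Scheduler` API:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyReq:
    fill_ids: List[int]  # full prompt token ids
    prefix_indices: List[int] = field(default_factory=list)  # already-prefilled tokens
    is_being_chunked: bool = False


def prefill_in_chunks(req: ToyReq, chunk_size: int) -> int:
    """Return how many prefill batches a prompt needs under chunked prefill."""
    being_chunked_req: Optional[ToyReq] = req
    num_batches = 0
    while being_chunked_req is not None:
        remaining = len(req.fill_ids) - len(req.prefix_indices)
        extend_len = min(remaining, chunk_size)
        # One prefill batch runs over `extend_len` tokens (model call omitted),
        # after which those tokens become part of the cached prefix.
        req.prefix_indices = req.fill_ids[: len(req.prefix_indices) + extend_len]
        num_batches += 1
        if len(req.prefix_indices) < len(req.fill_ids):
            req.is_being_chunked = True  # carried into the next prefill batch
        else:
            req.is_being_chunked = False  # last chunk: decoding can start
            being_chunked_req = None
    return num_batches


# A 300-token prompt with chunked-prefill-size 128 needs 3 prefill batches.
assert prefill_in_chunks(ToyReq(fill_ids=list(range(300))), chunk_size=128) == 3
```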
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e9bf7be8ee..76e3be0733 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -219,35 +219,28 @@ def __init__(
 
         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
+        self.being_chunked_req = None
        self.is_mixed_chunk = (
            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
        )
 
         # Init the FSM cache for constrained generation
-        if not server_args.skip_tokenizer_init:
-            self.regex_fsm_cache = FSMCache(
-                server_args.tokenizer_path,
-                {
-                    "tokenizer_mode": server_args.tokenizer_mode,
-                    "trust_remote_code": server_args.trust_remote_code,
-                },
-                skip_tokenizer_init=server_args.skip_tokenizer_init,
-                constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
-            )
+        self.regex_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+        )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-        assert (
-            server_args.schedule_conservativeness >= 0
-        ), "Invalid schedule_conservativeness"
-        self.min_new_token_ratio = min(
-            global_config.base_min_new_token_ratio
-            * server_args.schedule_conservativeness,
-            1.0,
+        self.min_new_token_ratio, self.init_new_token_ratio = (
+            global_config.adjust_new_token_ratio(server_args.schedule_conservativeness)
         )
-        self.new_token_ratio = self.min_new_token_ratio
-        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio = self.init_new_token_ratio
         self.batch_is_full = False
 
         # Init profiler
@@ -294,7 +287,7 @@ def event_loop_normal(self):
                 self.process_batch_result(batch, result)
             else:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -321,7 +314,7 @@ def event_loop_overlap(self):
                 self.process_batch_result(tmp_batch, tmp_result)
             elif batch is None:
                 self.check_memory()
-                self.new_token_ratio = global_config.init_new_token_ratio
+                self.new_token_ratio = self.init_new_token_ratio
 
             self.last_batch = batch
 
@@ -499,20 +492,18 @@ def check_memory(self):
             )
             exit(1) if crash_on_warning else None
 
-    def get_next_batch_to_run(self):
+    def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         if (
             self.last_batch
             and not self.last_batch.forward_mode.is_decode()
             and not self.last_batch.is_empty()
         ):
-            if self.current_inflight_req:
-                self.last_batch.filter_batch(
-                    current_inflight_req=self.current_inflight_req
-                )
-                self.tree_cache.cache_unfinished_req(self.current_inflight_req)
-                # Inflight request keeps its rid but will get a new req_pool_idx.
-                self.req_to_token_pool.free(self.current_inflight_req.req_pool_idx)
+            if self.being_chunked_req:
+                self.last_batch.filter_batch(being_chunked_req=self.being_chunked_req)
+                self.tree_cache.cache_unfinished_req(self.being_chunked_req)
+                # Being chunked request keeps its rid but will get a new req_pool_idx.
+                self.req_to_token_pool.free(self.being_chunked_req.req_pool_idx)
                 self.batch_is_full = False
             if not self.last_batch.is_empty():
                 if self.running_batch is None:
@@ -543,7 +534,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
         # Handle the cases where prefill is not allowed
         if (
             self.batch_is_full or len(self.waiting_queue) == 0
-        ) and self.current_inflight_req is None:
+        ) and self.being_chunked_req is None:
             return None
 
         running_bs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -566,15 +557,6 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
             num_mixed_running,
         )
 
-        has_inflight = self.current_inflight_req is not None
-        if has_inflight:
-            self.current_inflight_req.init_next_round_input(
-                None if prefix_computed else self.tree_cache
-            )
-            self.current_inflight_req = adder.add_inflight_req(
-                self.current_inflight_req
-            )
-
         if self.lora_paths:
             lora_set = (
                 set([req.lora_path for req in self.running_batch.reqs])
@@ -582,6 +564,13 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                 else set([])
             )
 
+        # NOTE: if there is request being chunked, we always add it first
+        has_being_chunked = self.being_chunked_req is not None
+        if has_being_chunked:
+            # NOTE: the prefix_indices of being-chunked prefill should align with the last prefill result
+            self.being_chunked_req.init_next_round_input()
+            adder.add_being_chunked_req(self.being_chunked_req)
+
         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
@@ -615,12 +604,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                 x for x in self.waiting_queue if x not in set(can_run_list)
             ]
 
-        if adder.new_inflight_req is not None:
-            assert self.current_inflight_req is None
-            self.current_inflight_req = adder.new_inflight_req
-
-        if self.current_inflight_req:
-            self.current_inflight_req.is_inflight_req += 1
+        # Update new round being chunked request
+        self.being_chunked_req = adder.new_chunked_req
 
         # Print stats
         if self.tp_rank == 0:
@@ -649,7 +634,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                     f"#cached-token: {adder.log_hit_tokens}, "
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
                 )
             else:
                 logger.info(
@@ -660,7 +645,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
                    f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                    f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                     f"#running-req: {running_bs}, "
-                    f"#queue-req: {len(self.waiting_queue) + has_inflight}"
+                    f"#queue-req: {len(self.waiting_queue) + has_being_chunked}"
                 )
 
         # Create a new batch
@@ -709,7 +694,7 @@ def update_running_batch(self):
             self.waiting_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_decay,
+                self.new_token_ratio - global_config.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
@@ -783,10 +768,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             # Check finish conditions
             logprob_pt = 0
             for i, req in enumerate(batch.reqs):
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
-                else:
-                    # Inflight reqs' prefill is not finished
+                if not req.is_being_chunked:
+                    # Being chunked reqs' prefill is not finished
                     req.completion_tokens_wo_jump_forward += 1
                     req.output_ids.append(next_token_ids[i])
                     req.check_finished()
@@ -812,10 +795,8 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
                 req.embedding = embeddings[i]
-                if req.is_inflight_req > 0:
-                    req.is_inflight_req -= 1
-                else:
-                    # Inflight reqs' prefill is not finished
+                if not req.is_being_chunked:
+                    # Being chunked reqs' prefill is not finished
                     # dummy output token for embedding models
                     req.output_ids.append(0)
                     req.check_finished()
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 20fc9d52da..baea2fa520 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -658,6 +658,7 @@ def run_mmlu_test(
     chunked_prefill_size=32,
 ):
     other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args += ["--mem-fraction-static", "0.85"]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3f8a1fecb1..2b1be4ed76 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -5,6 +5,7 @@
 
 suites = {
     "minimal": [
+        "test_radix_attention.py",
        "models/test_embedding_models.py",
        "models/test_generation_models.py",
        "models/test_lora.py",
diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py
new file mode 100644
index 0000000000..292a7b454d
--- /dev/null
+++ b/test/srt/test_radix_attention.py
@@ -0,0 +1,112 @@
+import os
+import random
+import unittest
+
+import requests
+
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    kill_child_process,
+    popen_launch_server,
+)
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_test(base_url, nodes):
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
+
+
+class TestRadixCacheFCFS(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chunked-prefill-size",
+                "128",
+                "--max-total-tokens",
+                "20000",
+                "--schedule-policy",
+                "fcfs",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_radix_attention(self):
+        nodes = gen_radix_tree()
+        run_test(self.base_url, nodes)
+
+
+class TestRadixCacheLPM(TestRadixCacheFCFS):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chunked-prefill-size",
+                "128",
+                "--max-total-tokens",
+                "20000",
+                "--schedule-policy",
+                "lpm",
+            ],
+        )
+
+
+if __name__ == "__main__":
+    os.environ["SGLANG_TEST_RETRACT"] = "true"
+    unittest.main()
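For reference, the batched request shape that `run_test` sends can also be exercised by hand against an already-running server. A minimal sketch; the URL below is a placeholder (the test uses `DEFAULT_URL_FOR_TEST`), and the test file itself can also be run directly since it has a `__main__` block:

```python
import requests

base_url = "http://127.0.0.1:30000"  # placeholder; point this at your server

payload = {
    # Batched token-id prompts plus per-request sampling params, as in run_test().
    "input_ids": [[37] * 117, [37] * 117 + [42] * 64],
    "sampling_params": [
        {"max_new_tokens": 32, "temperature": 0},
        {"max_new_tokens": 16, "temperature": 0},
    ],
}

res = requests.post(base_url + "/generate", json=payload)
assert res.status_code == 200  # same success check the test applies
```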