diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 791509cbca..bd0602a2e9 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -36,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
@@ -55,28 +55,47 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
-    # Whether it is a single request or a batch request
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
             raise ValueError("Either text or input_ids should be provided.")
-        if (
-            isinstance(self.sampling_params, dict)
-            and self.sampling_params.get("n", 1) != 1
-        ):
-            is_single = False
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
         else:
-            if self.text is not None:
-                is_single = isinstance(self.text, str)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
             else:
-                is_single = isinstance(self.input_ids[0], int)
-            self.is_single = is_single
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
 
-        if is_single:
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
@@ -88,79 +107,54 @@ def post_init(self):
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num_list = []
-            if isinstance(self.sampling_params, dict):
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
             else:
-                parallel_sample_num = 1
-            self.parallel_sample_num = parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size
 
             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # multi-image with n > 1
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1
 
 
 @dataclass
@@ -199,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -255,8 +247,6 @@ class RewardReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         self.is_single = isinstance(self.conv[0], dict)
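
[Note, not part of the patch] The new post_init broadcasts every per-request field (rid, sampling_params, image_data, return_logprob, ...) to a fixed number of slots: batch_size when n == 1, otherwise batch_size prefix-caching prefill slots followed by parallel_sample_num generation slots per prompt. A minimal standalone sketch of that sizing; the helper name is invented for illustration:

def expected_request_count(batch_size: int, parallel_sample_num: int) -> int:
    """How many entries post_init broadcasts for rid, sampling_params, etc."""
    if parallel_sample_num == 1:
        return batch_size
    # The first `batch_size` entries cache the shared prefix (prefill only),
    # followed by `parallel_sample_num` generation entries for each prompt.
    return batch_size + parallel_sample_num * batch_size

# Example: 2 prompts with n=3 -> 2 prefix-caching entries + 6 generation entries.
assert expected_request_count(batch_size=2, parallel_sample_num=3) == 8
assert expected_request_count(batch_size=2, parallel_sample_num=1) == 2
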
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3103faec83..00c75fcda7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -159,58 +159,72 @@ async def generate_request(
             async for response in self._handle_batch_request(obj, request):
                 yield response
 
-    async def _handle_single_request(
+    async def _send_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
-        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
         if not is_cache_for_prefill:  # The normal case with a single prompt
-            not_use_index = index is None
-
-            rid = obj.rid if not_use_index else obj.rid[index]
-            input_text = obj.text if not_use_index else obj.text[index]
-            if hasattr(obj, "conv"):
-                # reward model
-                assert self.tokenizer is not None
-                conv = obj.conv if not_use_index else obj.conv[index]
-                input_text = self.tokenizer.apply_chat_template(conv, tokenize=False)
-                input_ids = self.tokenizer.encode(input_text)
-            elif obj.input_ids is None:
-                assert self.tokenizer is not None
-                input_ids = self.tokenizer.encode(input_text)
+            if index is None:
+                rid = obj.rid
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = obj.text if obj.text is not None else None
+                    input_ids = obj.input_ids
+
+                sampling_params = self._get_sampling_params(obj.sampling_params)
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data, obj
+                    )
+                    return_logprob = obj.return_logprob
+                    logprob_start_len = obj.logprob_start_len
+                    top_logprobs_num = obj.top_logprobs_num
             else:
-                input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
+                rid = obj.rid[index]
+                if hasattr(obj, "conv"):
+                    # reward model
+                    conv = obj.conv[index]
+                    input_text = self.tokenizer.apply_chat_template(
+                        conv, tokenize=False
+                    )
+                    input_ids = self.tokenizer.encode(input_text)
+                elif obj.input_ids is None:
+                    input_text = obj.text[input_id_index]
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    input_text = (
+                        obj.text[input_id_index] if obj.text is not None else None
+                    )
+                    input_ids = obj.input_ids[input_id_index]
 
-            self._validate_input_length(input_ids)
+                sampling_params = self._get_sampling_params(obj.sampling_params[index])
+                if self.is_generation:
+                    image_inputs = await self.image_processor.process_images_async(
+                        obj.image_data[index], obj
+                    )
+                    return_logprob = obj.return_logprob[index]
+                    logprob_start_len = obj.logprob_start_len[index]
+                    top_logprobs_num = obj.top_logprobs_num[index]
 
-            sampling_params = self._get_sampling_params(
-                obj.sampling_params if not_use_index else obj.sampling_params[index]
-            )
+            self._validate_input_length(input_ids)
 
-            if self.is_generation:
-                image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data if not_use_index else obj.image_data[index], obj
-                )
-                return_logprob = (
-                    obj.return_logprob if not_use_index else obj.return_logprob[index]
-                )
-                logprob_start_len = (
-                    obj.logprob_start_len
-                    if not_use_index
-                    else obj.logprob_start_len[index]
-                )
-                top_logprobs_num = (
-                    obj.top_logprobs_num
-                    if not_use_index
-                    else obj.top_logprobs_num[index]
-                )
         else:  # A prefill request to cache the common prompt for parallel sampling
             assert self.is_generation
             if obj.text is not None:
                 if isinstance(obj.text, list):
-                    input_text = obj.text[index]
+                    input_text = obj.text[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_text = obj.text
@@ -224,7 +238,7 @@ async def _handle_single_request(
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -235,7 +249,7 @@ async def _handle_single_request(
                     obj.input_ids[0], list
                 ):
                     # when obj["input_ids"] is List[List[int]]
-                    input_ids = obj.input_ids[index]
+                    input_ids = obj.input_ids[input_id_index]
                     rid = obj.rid[index]
                 else:
                     input_ids = obj.input_ids
@@ -263,7 +277,7 @@ async def _handle_single_request(
                 top_logprobs_num,
                 obj.stream,
                 (
-                    obj.lora_path[index]
+                    obj.lora_path[input_id_index]
                     if isinstance(obj.lora_path, list)
                     else obj.lora_path
                 ),
@@ -283,12 +297,30 @@ async def _handle_single_request(
                 input_ids,
                 sampling_params,
             )
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
+        return rid, input_ids
+
+    async def _handle_single_request(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput, RewardReqInput],
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
+        input_id_index: Optional[int] = None,
+        is_cache_for_prefill: Optional[bool] = False,
+    ):
+        rid, input_ids = await self._send_single_request(
+            obj,
+            index,
+            input_id_index=input_id_index,
+            is_cache_for_prefill=is_cache_for_prefill,
+        )
 
         # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
+
         if not is_cache_for_prefill:
             async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
@@ -312,14 +344,16 @@ async def _handle_batch_request(
                 input_id_result = [] if obj.input_ids is None else None
                 for i in range(batch_size):
                     async for input_id in self._handle_single_request(
-                        obj, request, index=i, is_cache_for_prefill=True
+                        obj,
+                        request,
+                        index=i,
+                        input_id_index=i,
+                        is_cache_for_prefill=True,
                     ):
                         if input_id_result is not None:
                             input_id_result.append(input_id)
-                if input_id_result is not None and len(input_id_result) > 1:
+                if input_id_result is not None:
                     obj.input_ids = input_id_result
-                elif input_id_result is not None:
-                    obj.input_ids = input_id_result[0]
         else:
             parallel_sample_num = 1
 
@@ -333,69 +367,10 @@ async def _handle_batch_request(
                 if parallel_sample_num != 1:
                     # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
                     index += batch_size - 1 - i
-                rid = obj.rid[index]
-                if parallel_sample_num == 1:
-                    ## select operation
-                    if hasattr(obj, "conv"):
-                        # reward model
-                        conv = obj.conv[i]
-                        input_text = self.tokenizer.apply_chat_template(
-                            conv, tokenize=False
-                        )
-                        input_ids = self.tokenizer.encode(input_text)
-                    elif obj.input_ids is None:
-                        input_text = obj.text[i]
-                        input_ids = self.tokenizer.encode(input_text)
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                else:
-                    assert obj.input_ids is not None
-                    if batch_size == 1:
-                        input_text = None
-                        input_ids = obj.input_ids
-                    else:
-                        input_text = None
-                        input_ids = obj.input_ids[i]
-                sampling_params = self._get_sampling_params(obj.sampling_params[index])
-
-                if self.is_generation:
-                    image_inputs = await self.image_processor.process_images_async(
-                        obj.image_data[index], obj
-                    )
-                    tokenized_obj = TokenizedGenerateReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        image_inputs,
-                        sampling_params,
-                        obj.return_logprob[index],
-                        obj.logprob_start_len[index],
-                        obj.top_logprobs_num[index],
-                        obj.stream,
-                        (
-                            obj.lora_path[index]
-                            if isinstance(obj.lora_path, list)
-                            else obj.lora_path
-                        ),
-                    )
-                elif isinstance(obj, EmbeddingReqInput):
-                    tokenized_obj = TokenizedEmbeddingReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                else:
-                    assert isinstance(obj, RewardReqInput)
-                    tokenized_obj = TokenizedRewardReqInput(
-                        rid,
-                        input_text,
-                        input_ids,
-                        sampling_params,
-                    )
-                self.send_to_scheduler.send_pyobj(tokenized_obj)
+                rid, _ = await self._send_single_request(
+                    obj, index, input_id_index=i, is_cache_for_prefill=False
+                )
 
                 event = asyncio.Event()
                 state = ReqState([], False, event)
@@ -418,7 +393,7 @@ async def _handle_batch_request(
         tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
         output_list = [None] * len(tasks)
 
-        # Recv results
+        # Fetch results
         while tasks:
             done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
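
[Note, not part of the patch] The send loop in _handle_batch_request places each (prompt i, sample j) pair into one of the slots that post_init allocated, using index = i * parallel_sample_num + j and then index += batch_size - 1 - i, so generation requests land after the batch_size prefix-caching prefills. The snippet below replays that arithmetic standalone; it assumes, per the in-code comment, that parallel_sample_num inside the loop already counts the extra prefill stage, and the helper name is invented for illustration:

def generation_slot(i: int, j: int, batch_size: int, parallel_sample_num: int) -> int:
    # Same arithmetic as the send loop above; j == 0 is the cached prefill and is skipped.
    index = i * parallel_sample_num + j
    index += batch_size - 1 - i
    return index

batch_size, per_prompt = 2, 4  # two prompts, n=3 samples plus the prefill stage
slots = [
    generation_slot(i, j, batch_size, per_prompt)
    for i in range(batch_size)
    for j in range(1, per_prompt)
]
# Prompt 0 fills slots 2-4 and prompt 1 fills slots 5-7, right after the
# two prefix-caching slots 0 and 1 (8 slots total, matching post_init's num).
assert slots == [2, 3, 4, 5, 6, 7]
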