diff --git a/README.md b/README.md
index da4061d819..95c2457d54 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
 
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 
 ### Method 2: From source
 ```
@@ -75,7 +75,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
 
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 
 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
diff --git a/docs/en/install.md b/docs/en/install.md
index b118a92894..161958e0ca 100644
--- a/docs/en/install.md
+++ b/docs/en/install.md
@@ -11,7 +11,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
 
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 
 ### Method 2: From source
 ```
@@ -26,7 +26,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
 
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 
 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index f0c55af625..eda2c7738d 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None
 
     # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None
 
     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None
 
 
 @dataclasses.dataclass
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e9bf7be8ee..55b05f8469 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -833,6 +833,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):
 
         if self.enable_overlap:
             logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 8032915e7b..9200612e87 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -103,6 +103,8 @@ def forward_thread_func_(self):
         while True:
             self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
             self.has_inflight_batch = True
             self.launch_event = threading.Event()
 
@@ -122,19 +124,48 @@ def forward_thread_func_(self):
             ] = next_token_ids
 
             # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
 
             self.launch_event.set()
-            self.copy_queue.put((copy_event, next_token_ids))
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
 
     def copy_thread_func(self):
         while True:
-            copy_event, next_token_ids = self.copy_queue.get()
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
             while not copy_event.query():
                 time.sleep(1e-5)
-            self.output_queue.put((None, next_token_ids.tolist()))
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = (
+                    logits_output.input_token_logprobs.tolist()
+                )
+                logits_output.normalized_prompt_logprobs = (
+                    logits_output.normalized_prompt_logprobs.tolist()
+                )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
@@ -172,3 +203,7 @@ def update_weights(self, recv_req: UpdateWeightReqInput):
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 52378f566a..8f9553b5af 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -263,7 +263,8 @@ def run_once():
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
             )
-            return forward(input_ids, forward_batch.positions, forward_batch)
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -318,23 +319,16 @@ def replay(self, forward_batch: ForwardBatch):
 
         # Replay
        self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
@@ -343,7 +337,11 @@ def replay(self, forward_batch: ForwardBatch):
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
 
         return logits_output
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 9246526225..3f8a1fecb1 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -19,7 +19,6 @@
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
-        "test_radix_attention.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_skip_tokenizer_init.py",
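
A quick note on the `logits_processor.py` change: defaulting every logprob field to `None` lets the overlap decode path and the CUDA-graph replay path construct a `LogitsProcessorOutput` from just `next_token_logits`, filling in logprob fields only when `return_logprob` is set. The sketch below is a minimal standalone illustration of that pattern, not the full sglang class; the `Optional[...]` annotations are added here for readability.

```python
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class LogitsProcessorOutput:
    # The logits of the next tokens. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # Logprob-related fields stay None unless the request asked for logprobs.
    next_token_logprobs: Optional[torch.Tensor] = None
    normalized_prompt_logprobs: Optional[torch.Tensor] = None
    input_token_logprobs: Optional[torch.Tensor] = None
    input_top_logprobs: Optional[List] = None
    output_top_logprobs: Optional[List] = None


logits = torch.randn(4, 32000)

# Fast path (no logprobs requested): only the logits are populated.
output = LogitsProcessorOutput(next_token_logits=logits)

# Logprob path: fill in only what was actually computed.
output.next_token_logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
```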
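In the new `return_logprob` branch of `forward_thread_func_`, the [#seq, vocab_size] logprob matrix is reduced to one value per sequence (the logprob of the token that was just sampled) before the non-blocking copy to the CPU. A small self-contained example of that `torch.arange` indexing, with made-up sizes:

```python
import torch

num_seqs, vocab_size = 4, 32000
next_token_logprobs = torch.randn(num_seqs, vocab_size).log_softmax(dim=-1)
next_token_ids = torch.randint(0, vocab_size, (num_seqs,))

# Advanced indexing picks next_token_logprobs[i, next_token_ids[i]] for every i,
# yielding a [#seq] tensor instead of the full [#seq, vocab_size] matrix.
chosen_logprobs = next_token_logprobs[torch.arange(num_seqs), next_token_ids]
assert chosen_logprobs.shape == (num_seqs,)
```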
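`copy_thread_func` and its shutdown sentinels follow a producer/consumer pattern: the forward thread issues `non_blocking=True` device-to-host copies, records a `torch.cuda.Event`, and hands the event plus the CPU tensors to a helper thread, which polls `event.query()` until the copies have landed and only then converts the tensors to Python lists. Below is a standalone sketch of that pattern under simplified assumptions; `copy_worker` and the queue variables are illustrative names, not sglang APIs.

```python
import queue
import threading
import time

import torch

copy_queue = queue.Queue()
output_queue = queue.Queue()


def copy_worker():
    while True:
        item = copy_queue.get()
        if item is None:  # sentinel: shut the thread down
            break
        copy_event, next_token_ids = item
        # Poll until the asynchronous device-to-host copy has completed.
        while not copy_event.query():
            time.sleep(1e-5)
        output_queue.put(next_token_ids.tolist())


threading.Thread(target=copy_worker, daemon=True).start()

if torch.cuda.is_available():
    next_token_ids = torch.randint(0, 32000, (8,), device="cuda")
    next_token_ids = next_token_ids.to("cpu", non_blocking=True)
    copy_event = torch.cuda.Event(blocking=True)
    copy_event.record()  # recorded on the current stream, after the copy was queued
    copy_queue.put((copy_event, next_token_ids))
    print(output_queue.get())  # blocks until the copy thread publishes the result

copy_queue.put(None)  # stop the worker thread
```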