From 09deb20deef8181a23f66c933ea74b86fee47366 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 11 May 2024 16:56:42 -0700 Subject: [PATCH] Optimize the memory usage of logits processor (#420) --- python/sglang/srt/layers/logits_processor.py | 4 +++- python/sglang/srt/managers/router/model_rpc.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index f95c307862..668cd33902 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -98,7 +98,9 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size] - all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6) + all_logprobs = all_logits.float() + all_logits = None + all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs( all_logprobs, input_metadata diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index f283635c36..55bd9e80ca 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -589,7 +589,7 @@ def handle_finished_requests(self, batch: Batch): + len(req.output_ids) - req.prompt_tokens, "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "finish_reason": req.finish_reason, + "finish_reason": str(req.finish_reason), "hit_stop_str": req.hit_stop_str, } if req.return_logprob: