Skip to content

Commit

Permalink
Optimize the memory usage of logits processor (#420)
Browse files — browse the repository at this point in the history
  • Loading branch information
merrymercy authored May 11, 2024
1 parent 33b242d commit 09deb20
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,9 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size]

all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
all_logprobs = all_logits.float()
all_logits = None
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/router/model_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def handle_finished_requests(self, batch: Batch):
+ len(req.output_ids)
- req.prompt_tokens,
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": req.finish_reason,
"finish_reason": str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
}
if req.return_logprob:
Expand Down

0 comments on commit 09deb20

Please sign in to comment.