Commit

[Minor] Fix code style (#2311)
merrymercy authored Dec 2, 2024
1 parent c54bda3 commit 18108ab
Showing 5 changed files with 249 additions and 274 deletions.
153 changes: 76 additions & 77 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -25,7 +25,6 @@
from typing import Dict, List, Optional, Tuple, Union

import fastapi
import torch
import uvloop
import zmq
import zmq.asyncio
@@ -337,6 +336,12 @@ async def _handle_batch_request(
rids.append(tmp_obj.rid)
else:
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
if batch_size > 128:
logger.warning(
"Sending a single large batch with parallel sampling (n > 1) has not been well optimized. "
"The performance might be better if you just duplicate the requests n times or use "
"many threads to send them one by one with parallel sampling (n > 1)."
)

# Tokenize all requests
objs = [obj[i] for i in range(batch_size)]
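Note: the warning above points to a client-side workaround. A minimal sketch of the "duplicate the requests n times" approach, assuming a running SGLang server that exposes a /generate HTTP endpoint; the URL, port, and sampling parameters below are illustrative assumptions, not taken from this commit:

import concurrent.futures

import requests

URL = "http://127.0.0.1:30000/generate"  # hypothetical local server address

def generate_once(prompt: str) -> str:
    # One request with n = 1; duplicating the prompt stands in for parallel sampling.
    payload = {
        "text": prompt,
        "sampling_params": {"temperature": 0.8, "max_new_tokens": 64},
    }
    return requests.post(URL, json=payload).json()["text"]

def sample_n(prompt: str, n: int) -> list[str]:
    # Send n independent copies from a thread pool instead of one request with n > 1.
    with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
        return list(pool.map(generate_once, [prompt] * n))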
@@ -494,9 +499,7 @@ async def update_weights_from_distributed(
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error(
f"Another parameter update is in progress in tokenizer manager"
)
logger.error("Another parameter update is in progress in tokenizer manager")
return (
False,
"Another parameter update is in progress. Please try again later.",
@@ -597,21 +600,85 @@ async def handle_loop(self):
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()

if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue

recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()

if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)

if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time)
/ completion_tokens
)
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
self.model_update_tmp.append(recv_obj)
# set the future once all results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
@@ -621,76 +688,8 @@ async def handle_loop(self):
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
)
continue

assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"

for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
continue

recv_obj.meta_info[i]["id"] = rid
if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.event.set()

if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]

if state.first_token_time is None:
state.first_token_time = time.time()
self.metrics_collector.observe_time_to_first_token(
state.first_token_time - state.created_time
)
else:
if completion_tokens >= 2:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.first_token_time)
/ (completion_tokens - 1)
)

if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
)
self.metrics_collector.inc_generation_tokens(completion_tokens)
self.metrics_collector.observe_e2e_request_latency(
time.time() - state.created_time
)
if completion_tokens >= 1:
self.metrics_collector.observe_time_per_output_token(
(time.time() - state.created_time) / completion_tokens
)
else:
raise ValueError(f"Invalid object: {recv_obj=}")
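The metrics branch above derives time-to-first-token from the request's creation time and averages the remaining latency over the generated tokens. The same arithmetic as a standalone sketch; the function and variable names are illustrative, not the TokenizerManager's actual attributes:

import time

def time_to_first_token(created_time: float, first_token_time: float) -> float:
    # Latency from request creation until the first decoded token arrives.
    return first_token_time - created_time

def inter_token_latency(first_token_time: float, completion_tokens: int) -> float:
    # Average time per output token after the first one; only meaningful
    # once at least two completion tokens have been produced.
    assert completion_tokens >= 2
    return (time.time() - first_token_time) / (completion_tokens - 1)

def e2e_time_per_output_token(created_time: float, completion_tokens: int) -> float:
    # Reported when the request finishes: total latency spread over all tokens.
    assert completion_tokens >= 1
    return (time.time() - created_time) / completion_tokens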

def convert_logprob_style(
self,
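Several branches of handle_loop (the weight-update and get_weights_by_name paths) share one pattern when dp_size > 1: buffer one result per data-parallel rank and resolve the pending future only after every rank has reported back. A minimal standalone sketch of that pattern, using illustrative names rather than the TokenizerManager's actual attributes:

import asyncio

class PerRankResultAggregator:
    # Buffers per-rank outputs and resolves a future once all dp ranks reply.
    # Must be constructed while an asyncio event loop is running.

    def __init__(self, dp_size: int):
        self.dp_size = dp_size
        self.buffer = []
        self.future = asyncio.get_running_loop().create_future()

    def add(self, result) -> None:
        # Called once per rank; the future fires when the last rank arrives.
        self.buffer.append(result)
        if len(self.buffer) == self.dp_size:
            self.future.set_result(self.buffer)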
10 changes: 0 additions & 10 deletions python/sglang/srt/model_executor/model_runner.py
@@ -218,16 +218,6 @@ def init_torch_distributed(self):
)
self.tp_group = get_tp_group()

# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if self.device == "cuda" and not all(
in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)

# Check memory for tensor parallelism
if self.tp_size > 1:
local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
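The block removed above disabled CUDA graph padding whenever the tensor-parallel group spanned more than one node, using the in_the_same_node_as helper on the TP CPU group. A simplified stand-in for that check based on gathering hostnames with torch.distributed; the real helper may use a different mechanism, so this is only an assumption-laden sketch:

import socket

import torch.distributed as dist

def all_ranks_on_same_node(group=None) -> bool:
    # Gather every rank's hostname and compare against the first one.
    # Requires an initialized torch.distributed process group.
    hostnames = [None] * dist.get_world_size(group=group)
    dist.all_gather_object(hostnames, socket.gethostname(), group=group)
    return all(name == hostnames[0] for name in hostnames)

# Usage mirroring the removed guard: if not all_ranks_on_same_node(cpu_group),
# treat the run as multi-node tensor parallelism.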

0 comments on commit 18108ab
