Simplify the event loop and expose --num-continuous-decode-steps as an argument #1652

Merged: 6 commits, Oct 13, 2024

python/sglang/global_config.py (1 change: 0 additions & 1 deletion)
@@ -19,7 +19,6 @@ def __init__(self):
self.new_token_ratio_decay = 0.001

# Runtime constants: others
self.num_continue_decode_steps = 10
self.retract_decode_steps = 20
self.flashinfer_workspace_size = os.environ.get(
"FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024

python/sglang/srt/managers/schedule_batch.py (16 changes: 16 additions & 0 deletions)
@@ -831,6 +831,22 @@ def get_model_worker_batch(self):
sampling_info=self.sampling_info,
)

def copy(self):
return ScheduleBatch(
reqs=self.reqs,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
tree_cache=self.tree_cache,
forward_mode=self.forward_mode,
output_token_ids=self.output_token_ids,
)

def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
f"#req={(len(self.reqs))})"
)


@dataclass
class ModelWorkerBatch:
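For orientation, a minimal usage sketch of the two helpers added above; `batch` is a placeholder for an existing ScheduleBatch instance and is not part of this diff.

# Illustrative only: `batch` stands in for a ScheduleBatch built by the scheduler.
snapshot = batch.copy()              # new ScheduleBatch sharing reqs and memory pools
assert snapshot.reqs is batch.reqs   # shallow copy: the request list is not duplicated
print(batch)                         # e.g. "ScheduleBatch(forward_mode=DECODE, #req=8)"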

python/sglang/srt/managers/scheduler.py (119 changes: 59 additions & 60 deletions)
@@ -20,6 +20,7 @@
import os
import time
import warnings
from types import SimpleNamespace
from typing import List, Optional, Union

import torch
@@ -106,7 +107,8 @@ def __init__(
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
else:
self.recv_from_tokenizer = self.send_to_detokenizer = None
self.recv_from_tokenizer = None
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)

# Init tokenizer
self.model_config = ModelConfig(
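
The SimpleNamespace stand-in above lets ranks without a detokenizer socket keep calling send_to_detokenizer.send_pyobj(...) unconditionally, which is what allows the old out_pyobjs buffer and send_results() to be dropped later in this diff. A minimal sketch of the pattern in isolation:

from types import SimpleNamespace

# On ranks that have no detokenizer socket, install an object whose
# send_pyobj silently discards whatever it is given.
send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
send_to_detokenizer.send_pyobj({"any": "payload"})  # no-op; nothing is sent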
@@ -190,7 +192,6 @@ def __init__(
# Init running status
self.waiting_queue: List[Req] = []
self.running_batch: ScheduleBatch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.num_generated_tokens = 0
@@ -247,13 +248,30 @@

@torch.inference_mode()
def event_loop(self):
self.last_batch = None

while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)

self.run_step()
batch = self.get_next_batch_to_run()

if batch:
result = self.run_batch(batch)
self.process_batch_result(batch, result)

self.send_results()
# Decode multiple steps to reduce the overhead
if batch.forward_mode.is_decode():
for _ in range(self.server_args.num_continuous_decode_steps - 1):
if not self.running_batch:
break
self.update_running_batch()
if not self.running_batch:
break
result = self.run_batch(batch)
self.process_batch_result(batch, result)

self.last_batch = batch

def recv_requests(self):
if self.tp_rank == 0:
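
Two reading notes on the event_loop hunk above: the run_step() and send_results() calls shown there come from the old loop (both methods are removed elsewhere in this diff), while the surrounding lines are the new loop; and the extra decode steps trade time-to-first-token for lower scheduling overhead, roughly as sketched below (the numbers are illustrative, not from the diff).

# With --num-continuous-decode-steps N, a scheduled decode batch runs one
# forward pass and then up to N - 1 more back-to-back before the loop polls
# recv_requests() again.
N = 4
max_decode_steps_per_iteration = 1 + (N - 1)  # = 4
# Requests arriving during those extra steps wait until the next iteration,
# which is the latency cost mentioned in the new CLI help text.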
@@ -286,7 +304,9 @@ def process_input_requests(self, recv_reqs: List):
self.abort_request(recv_req)
elif isinstance(recv_req, UpdateWeightReqInput):
success, message = self.update_weights(recv_req)
self.out_pyobjs.append(UpdateWeightReqOutput(success, message))
self.send_to_detokenizer.send_pyobj(
UpdateWeightReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -384,12 +404,6 @@ def handle_embedding_request(

self.waiting_queue.append(req)

def send_results(self):
if self.tp_rank == 0:
for obj in self.out_pyobjs:
self.send_to_detokenizer.send_pyobj(obj)
self.out_pyobjs = []

def print_decode_stats(self):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
@@ -427,44 +441,32 @@ def check_memory(self):
)
exit(1) if crash_on_warning else None

def run_step(self):
def get_next_batch_to_run(self):
# Merge prefill to the running batch
if (
self.last_batch
and not self.last_batch.forward_mode.is_decode()
and not self.last_batch.is_empty()
):
if self.running_batch is None:
self.running_batch = self.last_batch
else:
self.running_batch.merge_batch(self.last_batch)

# Prefill first
new_batch = self.get_new_batch_prefill()
if new_batch is not None:
# Run a new prefill batch
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs)
# )
result = self.run_batch(new_batch)
self.process_batch_result(new_batch, result)
else:
if self.running_batch is not None:
# Run a few decode batches continuously for reducing overhead
for _ in range(global_config.num_continue_decode_steps):
batch = self.get_new_batch_decode()

if batch:
# replace run_batch with the uncommented line to use pytorch profiler
# result = pytorch_profile(
# "profile_decode_step",
# self.run_batch,
# batch,
# data_size=len(batch.reqs),
# )
result = self.run_batch(batch)
self.process_batch_result(batch, result)

if self.running_batch.is_empty():
self.running_batch = None
return new_batch

if self.running_batch is None:
break

if self.out_pyobjs and self.running_batch.has_stream:
break
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
# Run decode
if self.running_batch is not None:
self.update_running_batch()
if not self.running_batch:
return None
return self.running_batch
else:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio

def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Handle the cases where prefill is not allowed
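
The old run_step and the new get_next_batch_to_run are interleaved in the flattened hunk above, which makes the control flow hard to follow. A condensed paraphrase of the new method (intended to read as a method on the Scheduler class; see the diff for the exact code):

def get_next_batch_to_run(self):
    # 1) Fold the previous prefill batch, if any, into the running decode batch.
    if (
        self.last_batch
        and not self.last_batch.forward_mode.is_decode()
        and not self.last_batch.is_empty()
    ):
        if self.running_batch is None:
            self.running_batch = self.last_batch
        else:
            self.running_batch.merge_batch(self.last_batch)

    # 2) A new prefill batch, when one can be formed, takes priority.
    new_batch = self.get_new_batch_prefill()
    if new_batch is not None:
        return new_batch

    # 3) Otherwise keep decoding the running batch; with nothing to run,
    #    do the idle-time housekeeping.
    if self.running_batch is not None:
        self.update_running_batch()
        return self.running_batch  # may become None if the batch emptied out
    self.check_memory()
    self.new_token_ratio = global_config.init_new_token_ratio
    return None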
@@ -607,7 +609,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:

return new_batch

def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
def update_running_batch(self):
batch = self.running_batch

# Check if decode out of memory
@@ -636,11 +638,11 @@ def get_new_batch_decode(self) -> Optional[ScheduleBatch]:
if jump_forward_reqs:
self.batch_is_full = False
if batch.is_empty():
return None
self.running_batch = None
return

# Update batch tensors
batch.prepare_for_decode()
return batch

def run_batch(self, batch: ScheduleBatch):
if self.is_generation:
@@ -657,16 +659,19 @@ def run_batch(self, batch: ScheduleBatch):
)
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
return logits_output, next_token_ids
ret = logits_output, next_token_ids
else: # embedding or reward model
assert batch.extend_num_tokens != 0
model_worker_batch = batch.get_model_worker_batch()
embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
return embeddings
ret = embeddings
return ret

def process_batch_result(self, batch: ScheduleBatch, result):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
else:
self.process_batch_result_prefill(batch, result)

@@ -728,7 +733,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):
)
else: # embedding or reward model
assert batch.extend_num_tokens != 0
embeddings = result
embeddings = result.tolist()

# Check finish conditions
for i, req in enumerate(batch.reqs):
@@ -750,12 +755,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result):

self.handle_finished_requests(batch)

if not batch.is_empty():
if self.running_batch is None:
self.running_batch = batch
else:
self.running_batch.merge_batch(batch)

def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
@@ -951,7 +950,7 @@ def handle_finished_requests(self, batch: ScheduleBatch):
# Send to detokenizer
if output_rids:
if self.is_generation:
self.out_pyobjs.append(
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
output_rids,
output_vids,
@@ -965,7 +964,7 @@
)
)
else: # embedding or reward model
self.out_pyobjs.append(
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
output_rids,
output_embeddings,

python/sglang/srt/managers/tp_worker.py (2 changes: 1 addition & 1 deletion)
@@ -118,7 +118,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings.tolist()
embeddings = logits_output.embeddings
return embeddings

def update_weights(self, recv_req: UpdateWeightReqInput):
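
The only change in tp_worker.py is that forward_batch_embedding now returns the raw tensor; the conversion to Python lists moves to the scheduler's process_batch_result_prefill (the result.tolist() line above), so the worker-to-scheduler handoff stays a tensor. A trivial sketch with a dummy tensor standing in for the model output:

import torch

# Placeholder for logits_output.embeddings as produced by the model runner.
embeddings = torch.zeros(2, 4)                 # worker side: stays a torch.Tensor
per_request_embeddings = embeddings.tolist()   # scheduler side: plain Python lists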

python/sglang/srt/server_args.py (9 changes: 9 additions & 0 deletions)
@@ -111,6 +111,7 @@ class ServerArgs:
torchao_config: str = ""
enable_p2p_check: bool = False
triton_attention_reduce_in_fp32: bool = False
num_continuous_decode_steps: int = 1

def __post_init__(self):
# Set missing default values
@@ -559,6 +560,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
)
parser.add_argument(
"--num-continuous-decode-steps",
type=int,
default=ServerArgs.num_continuous_decode_steps,
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
"This can potentially increase throughput but may also increase time-to-first-token latency. "
"The default value is 1, meaning only run one decoding step at a time.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
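
A sketch of how the new flag reaches ServerArgs through the two methods shown above (add_cli_args and from_cli_args); the model path is a placeholder and nothing is launched or loaded here:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

# Equivalent to passing --num-continuous-decode-steps 4 on the launch command line;
# the flag defaults to 1 (one decode step per scheduling pass) when omitted.
args = parser.parse_args(
    ["--model-path", "dummy/model", "--num-continuous-decode-steps", "4"]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.num_continuous_decode_steps)  # 4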