Improve multi-node stability (sgl-project#1171)
merrymercy authored and chenxu02 committed Sep 11, 2024
1 parent cd10654 commit bba322e
Showing 11 changed files with 110 additions and 77 deletions.
9 changes: 8 additions & 1 deletion python/sglang/launch_server.py
@@ -1,14 +1,21 @@
"""Launch the inference server."""

import argparse
import os

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process

if __name__ == "__main__":
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)

launch_server(server_args)
try:
launch_server(server_args)
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
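
For context, here is a minimal sketch of what a child-process cleanup helper like kill_child_process could look like, assuming a psutil-based implementation; the helper's actual body is not part of this diff and may differ.

# Hypothetical sketch of a child-process cleanup helper, assuming psutil is
# available. The real kill_child_process in sglang.srt.utils may differ.
import os
import signal

import psutil


def kill_child_process_sketch(pid=None, including_parent=True):
    """Send SIGKILL to all descendants of `pid` (and optionally `pid` itself)."""
    if pid is None:
        pid = os.getpid()
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for child in parent.children(recursive=True):
        try:
            child.send_signal(signal.SIGKILL)
        except psutil.NoSuchProcess:
            pass
    if including_parent:
        parent.send_signal(signal.SIGKILL)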
16 changes: 11 additions & 5 deletions python/sglang/srt/hf_transformers_utils.py
@@ -233,6 +233,8 @@ def __init__(self, tokenizer_path):
}
assert tok_dict["word_split"] == "V1"

default_allowed_special = None

kwargs = {
"name": name,
"pat_str": tok_dict.get("pat_str", PAT_STR_B),
@@ -246,14 +248,18 @@ def __init__(self, tokenizer_path):
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
if "vocab_size" in tok_dict:
kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"

DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

tokenizer = tiktoken.Encoding(**kwargs)
tokenizer._default_allowed_special = default_allowed_special or set()
tokenizer._default_allowed_special |= {"<|separator|>"}
tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS

def encode_patched(
self,
@@ -270,14 +276,14 @@ def encode_patched(
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
disallowed_special=(),
)

tokenizer.encode = functools.partial(encode_patched, tokenizer)

# Convert to HF interface
self.tokenizer = tokenizer
self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
self.eos_token_id = tokenizer._special_tokens[EOS]
self.vocab_size = tokenizer.n_vocab
self.chat_template = Template(
"{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
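
The hunk above monkey-patches the tokenizer's encode method by binding a wrapper to the instance with functools.partial. A generic, self-contained illustration of that pattern follows; the class and names here are illustrative, not part of the sglang API.

# Generic illustration of instance-level method patching with functools.partial.
import functools


class Greeter:
    def greet(self, name, punctuation="!"):
        return f"Hello, {name}{punctuation}"


def greet_patched(self, name, punctuation="!"):
    # Force one keyword argument regardless of what the caller passes,
    # mirroring how encode_patched pins disallowed_special=().
    return Greeter.greet(self, name, punctuation=".")


g = Greeter()
g.greet = functools.partial(greet_patched, g)  # shadow the bound method on this instance
print(g.greet("world"))  # -> Hello, world.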
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/controller_multi.py
@@ -212,6 +212,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
finally:
for w in controller.workers:
os.kill(w.proc.pid, 9)
kill_parent_process()
2 changes: 0 additions & 2 deletions python/sglang/srt/managers/controller_single.py
@@ -167,6 +167,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
for t in controller.tp_procs:
os.kill(t.pid, 9)
kill_parent_process()
20 changes: 5 additions & 15 deletions python/sglang/srt/managers/schedule_batch.py
@@ -16,7 +16,6 @@
"""Meta data for requests and batches"""

import logging
import warnings
from dataclasses import dataclass
from typing import List, Optional, Union

@@ -270,7 +269,7 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):

if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logging.warning(
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
@@ -753,7 +752,7 @@ def merge(self, other: "ScheduleBatch"):
)
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
def sample(self, logits: torch.Tensor):
# TODO(lsyin): move this into a part of layer and run with CUDA Graph
# Post process logits
logits = logits.contiguous()
@@ -791,7 +790,7 @@ def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
)

if not torch.all(success):
logging.warning("Sampling failed, fallback to top_k=1 strategy")
logger.warning(f"Sampling failed. Fallback to top_k=1 strategy. {logits=}")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
@@ -808,16 +807,6 @@ def sample(self, logits: torch.Tensor, is_multi_node_tp=False):

self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)

if is_multi_node_tp:
# If the tensor parallelism spans across multiple nodes, there is some indeterminism
# that can cause the TP workers to generate different tokens, so we need to
# sync here
torch.distributed.all_reduce(
batch_next_token_ids,
op=dist.ReduceOp.MIN,
group=get_tensor_model_parallel_group().device_group,
)

return batch_next_token_ids


@@ -835,7 +824,8 @@ def top_k_top_p_sampling_from_probs_torch(
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
try:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError:
except RuntimeError as e:
logger.warning(f"Sampling error: {e}")
batch_next_token_ids = torch.zeros(
(probs_sort.shape[0],), dtype=torch.int32, device=probs.device
)
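
For reference, a self-contained sketch of torch-based top-k/top-p sampling with a greedy fallback on failure, in the spirit of top_k_top_p_sampling_from_probs_torch above. The function name and exact masking details here are illustrative, not the sglang implementation.

# Illustrative top-k/top-p sampling over a probability matrix with a greedy
# fallback when the renormalized distribution is degenerate.
import torch


def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    cumsum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens outside the nucleus (top-p) and beyond the top-k cutoff.
    probs_sort[cumsum - probs_sort > top_p] = 0.0
    probs_sort[..., top_k:] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    try:
        sampled = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError:
        # Degenerate rows (e.g. NaNs after renormalization): fall back to top-1.
        return torch.argmax(probs, dim=-1)
    return torch.gather(probs_idx, dim=-1, index=sampled).squeeze(-1)


if __name__ == "__main__":
    logits = torch.randn(4, 32)
    print(sample_top_k_top_p(torch.softmax(logits, dim=-1), top_k=10, top_p=0.9))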
16 changes: 10 additions & 6 deletions python/sglang/srt/managers/tp_worker.py
@@ -133,6 +133,13 @@ def __init__(
self.model_config.context_len - 1,
self.max_total_num_tokens - 1,
)

# Sync random seed
server_args.random_seed = broadcast_recv_input(
[server_args.random_seed],
self.tp_rank,
self.model_runner.tp_group.cpu_group,
)[0]
set_random_seed(server_args.random_seed)

# Print info
@@ -474,9 +481,7 @@ def forward_prefill_batch(self, batch: ScheduleBatch):
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids = batch.sample(
output.next_token_logits, self.model_runner.is_multi_node_tp
)
next_token_ids = batch.sample(output.next_token_logits)

# Move logprobs to cpu
if output.next_token_logprobs is not None:
@@ -636,9 +641,7 @@ def forward_decode_batch(self, batch: ScheduleBatch):

# Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids = batch.sample(
output.next_token_logits, self.model_runner.is_multi_node_tp
)
next_token_ids = batch.sample(output.next_token_logits)

# Move logprobs to cpu
if output.next_token_logprobs is not None:
@@ -879,6 +882,7 @@ def broadcast_recv_input(

dist.broadcast(tensor_size, src=0, group=dist_group)
dist.broadcast(tensor_data, src=0, group=dist_group)
return data
else:
tensor_size = torch.tensor([0], dtype=torch.long)
dist.broadcast(tensor_size, src=0, group=dist_group)
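
The "Sync random seed" hunk above broadcasts rank 0's seed to every tensor-parallel rank so that all workers sample identically, replacing the all_reduce-based token sync that was removed from sample(). This diff uses sglang's own broadcast_recv_input helper; the sketch below shows the same idea with the stock torch.distributed.broadcast_object_list API, purely for clarity, and assumes an already-initialized process group.

# Simplified seed sync from rank 0 to all ranks (requires an initialized
# torch.distributed process group; not the sglang helper itself).
import torch.distributed as dist


def sync_random_seed(seed: int, rank: int, cpu_group) -> int:
    # Only rank 0's value matters; other ranks receive it in place.
    obj = [seed if rank == 0 else None]
    dist.broadcast_object_list(obj, src=0, group=cpu_group)
    return obj[0]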
14 changes: 12 additions & 2 deletions python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -84,13 +84,20 @@ def set_torch_compile_config():


class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
def __init__(
self,
model_runner,
max_batch_size_to_capture: int,
use_torch_compile: bool,
disable_padding: bool,
):
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
self.output_buffers = {}
self.flashinfer_handlers = {}
self.graph_memory_pool = None
self.disable_padding = disable_padding

# Common inputs
self.max_bs = max_batch_size_to_capture
@@ -142,7 +149,10 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
set_torch_compile_config()

def can_run(self, batch_size):
return batch_size <= self.max_bs
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs

def capture(self, batch_size_list):
self.batch_size_list = batch_size_list
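
The new disable_padding flag changes what can_run accepts: with padding enabled, any batch up to max_bs can be padded to a captured graph size; with padding disabled, only exactly-captured sizes qualify. A small sketch of how a caller might pick the graph to replay under that rule; the function name is illustrative only.

# Illustrative batch-size selection for CUDA-graph replay.
from typing import List, Optional


def pick_graph_batch_size(
    batch_size: int, captured_sizes: List[int], disable_padding: bool
) -> Optional[int]:
    if disable_padding:
        # Exact match only, as in can_run when disable_padding is set.
        return batch_size if batch_size in captured_sizes else None
    # Otherwise pad up to the smallest captured size that fits.
    candidates = [s for s in sorted(captured_sizes) if s >= batch_size]
    return candidates[0] if candidates else None


print(pick_graph_batch_size(3, [1, 2, 4, 8], disable_padding=False))  # 4
print(pick_graph_batch_size(3, [1, 2, 4, 8], disable_padding=True))   # None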
1 change: 1 addition & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -465,6 +465,7 @@ def init_cuda_graphs(self):
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
disable_padding=self.server_args.disable_cuda_graph_padding,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
90 changes: 52 additions & 38 deletions python/sglang/srt/server.py
@@ -24,7 +24,6 @@
import logging
import multiprocessing as mp
import os
import sys
import threading
import time
from http import HTTPStatus
@@ -112,10 +111,25 @@ async def health(request: Request) -> Response:

@app.get("/health")
async def health() -> Response:
"""Health check."""
"""Check the health of the http server."""
return Response(status_code=200)


@app.get("/health_generate")
async def health_generate(request: Request) -> Response:
"""Check the health of the inference server by generating one token."""
gri = GenerateReqInput(
text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
)
try:
async for _ in tokenizer_manager.generate_request(gri, request):
break
return Response(status_code=200)
except Exception as e:
logger.exception(e)
return Response(status_code=503)


@app.get("/get_model_info")
async def get_model_info():
result = {
@@ -301,27 +315,29 @@ def launch_server(
server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

# Launch processes for multi-node tensor parallelism
if server_args.nnodes > 1:
if server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [
i for _ in range(server_args.nnodes) for i in range(tp_size_local)
]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
if server_args.nnodes > 1 and server_args.node_rank != 0:
tp_size_local = server_args.tp_size // server_args.nnodes
gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
tp_rank_range = list(
range(
server_args.node_rank * tp_size_local,
(server_args.node_rank + 1) * tp_size_local,
)
while True:
pass
)
procs = launch_tp_servers(
gpu_ids,
tp_rank_range,
server_args,
ports[3],
model_overide_args,
)

try:
for p in procs:
p.join()
finally:
kill_child_process(os.getpid(), including_parent=False)
return

# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
@@ -356,15 +372,11 @@ def launch_server(
if controller_init_state != "init ok" or detoken_init_state != "init ok":
proc_controller.kill()
proc_detoken.kill()
print(
f"Initialization failed. controller_init_state: {controller_init_state}",
flush=True,
raise RuntimeError(
"Initialization failed. "
f"controller_init_state: {controller_init_state}, "
f"detoken_init_state: {detoken_init_state}"
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_controller.is_alive() and proc_detoken.is_alive()

# Add api key authorization
@@ -373,12 +385,12 @@

# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
)
t.start()

# Listen for requests
try:
# Listen for requests
uvicorn.run(
app,
host=server_args.host,
@@ -426,7 +438,7 @@ def _set_envs_and_config(server_args: ServerArgs):
)


def _wait_and_warmup(server_args, pipe_finish_writer):
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
headers = {}
url = server_args.url()
if server_args.api_key:
@@ -449,8 +461,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if not success:
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return

# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -475,12 +488,13 @@
timeout=600,
)
assert res.status_code == 200, f"{res}"
except Exception as e:
except Exception:
last_traceback = get_exception_traceback()
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
return

logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
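
The new /health_generate endpoint checks readiness by actually generating one token, which is stricter than /health (HTTP server only) and is useful for multi-node deployments where a rank can hang after the HTTP server is up. A hypothetical readiness probe against it, assuming the server listens on localhost:30000:

# Hypothetical readiness probe for a deployment script; URL and timings are
# assumptions, not part of this diff.
import time

import requests


def wait_until_healthy(base_url: str = "http://localhost:30000", timeout: float = 300.0) -> bool:
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            # /health_generate forces the model to produce one token,
            # so it only succeeds once the whole pipeline is live.
            if requests.get(f"{base_url}/health_generate", timeout=30).status_code == 200:
                return True
        except requests.RequestException:
            pass
        time.sleep(5)
    return False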
(The remaining file diffs were not loaded on this page.)