Skip to content

Commit

Permalink
Crash the server correctly during error (#2231)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Nov 28, 2024
1 parent db674e3 commit d4fc1a7
Show file tree
Hide file tree
Showing 46 changed files with 147 additions and 139 deletions.
9 changes: 3 additions & 6 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import json
import logging
import multiprocessing
import os
import time
from typing import Tuple

Expand All @@ -62,11 +63,7 @@
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_child_process,
suppress_other_loggers,
)
from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers


@dataclasses.dataclass
Expand Down Expand Up @@ -468,4 +465,4 @@ def main(server_args, bench_args):
main(server_args, bench_args)
finally:
if server_args.tp_size != 1:
kill_child_process()
kill_process_tree(os.getpid(), include_parent=False)
7 changes: 4 additions & 3 deletions python/sglang/bench_one_batch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import itertools
import json
import multiprocessing
import os
import time
from typing import Tuple

Expand All @@ -23,7 +24,7 @@

from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_child_process
from sglang.srt.utils import kill_process_tree


@dataclasses.dataclass
Expand Down Expand Up @@ -69,7 +70,7 @@ def launch_server_internal(server_args):
except Exception as e:
raise e
finally:
kill_child_process()
kill_process_tree(os.getpid(), include_parent=False)


def launch_server_process(server_args: ServerArgs):
Expand Down Expand Up @@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
finally:
if proc:
kill_child_process(proc.pid, include_self=True)
kill_process_tree(proc.pid)

print(f"\nResults are saved to {bench_args.result_filename}")

Expand Down
4 changes: 2 additions & 2 deletions python/sglang/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_child_process
from sglang.srt.utils import kill_process_tree

if __name__ == "__main__":
server_args = prepare_server_args(sys.argv[1:])

try:
launch_server(server_args)
finally:
kill_child_process()
kill_process_tree(os.getpid(), include_parent=False)
18 changes: 7 additions & 11 deletions python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import logging
import multiprocessing as mp
import signal
import threading
from enum import Enum, auto

import psutil
import zmq

from sglang.srt.managers.io_struct import (
Expand All @@ -26,13 +28,7 @@
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
bind_port,
configure_logger,
get_zmq_socket,
kill_parent_process,
suppress_other_loggers,
)
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -235,7 +231,7 @@ def run_data_parallel_controller_process(
pipe_writer,
):
configure_logger(server_args)
suppress_other_loggers()
parent_process = psutil.Process().parent()

try:
controller = DataParallelController(server_args, port_args)
Expand All @@ -244,6 +240,6 @@ def run_data_parallel_controller_process(
)
controller.event_loop()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
traceback = get_exception_traceback()
logger.error(f"DataParallelController hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
11 changes: 7 additions & 4 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@

import dataclasses
import logging
import signal
from collections import OrderedDict
from typing import List, Union

import psutil
import zmq

from sglang.srt.hf_transformers_utils import get_tokenizer
Expand All @@ -28,7 +30,7 @@
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process
from sglang.srt.utils import configure_logger, get_zmq_socket
from sglang.utils import find_printable_text, get_exception_traceback

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -193,11 +195,12 @@ def run_detokenizer_process(
port_args: PortArgs,
):
configure_logger(server_args)
parent_process = psutil.Process().parent()

try:
manager = DetokenizerManager(server_args, port_args)
manager.event_loop()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
traceback = get_exception_traceback()
logger.error(f"DetokenizerManager hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
13 changes: 8 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import os
import signal
import threading
import time
import warnings
Expand All @@ -23,6 +24,7 @@
from types import SimpleNamespace
from typing import List, Optional

import psutil
import torch
import zmq

Expand Down Expand Up @@ -73,7 +75,6 @@
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
kill_parent_process,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
Expand Down Expand Up @@ -316,6 +317,7 @@ def __init__(
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()
self.parent_process = psutil.Process().parent()

# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
Expand Down Expand Up @@ -359,7 +361,7 @@ def watchdog_thread(self):
self.watchdog_last_time = time.time()
time.sleep(self.watchdog_timeout / 2)

kill_parent_process()
self.parent_process.send_signal(signal.SIGQUIT)

@torch.no_grad()
def event_loop_normal(self):
Expand Down Expand Up @@ -1423,6 +1425,7 @@ def run_scheduler_process(
configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

suppress_other_loggers()
parent_process = psutil.Process().parent()

try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
Expand All @@ -1434,6 +1437,6 @@ def run_scheduler_process(
else:
scheduler.event_loop_normal()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
traceback = get_exception_traceback()
logger.error(f"Scheduler hit an exception: {traceback}")
parent_process.send_signal(signal.SIGQUIT)
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/tokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import get_zmq_socket, kill_child_process
from sglang.srt.utils import get_zmq_socket, kill_process_tree

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

Expand Down Expand Up @@ -532,7 +532,7 @@ async def sigterm_watchdog(self):
else:
break

kill_child_process(include_self=True)
kill_process_tree(os.getpid(), include_parent=True)
sys.exit(0)

async def handle_loop(self):
Expand Down
13 changes: 11 additions & 2 deletions python/sglang/srt/managers/tp_worker_overlap_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@

import dataclasses
import logging
import signal
import threading
from queue import Queue
from typing import Optional

import psutil
import torch

from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,6 +73,7 @@ def __init__(
target=self.forward_thread_func,
)
self.forward_thread.start()
self.parent_process = psutil.Process().parent()

def get_worker_info(self):
return self.worker.get_worker_info()
Expand All @@ -87,8 +91,13 @@ def get_memory_pool(self):
)

def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
try:
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
except Exception:
traceback = get_exception_traceback()
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
self.parent_process.send_signal(signal.SIGQUIT)

@torch.no_grad()
def forward_thread_func_(self):
Expand Down
21 changes: 16 additions & 5 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import logging
import multiprocessing as mp
import os
import signal
import sys
import threading
import time
from http import HTTPStatus
Expand Down Expand Up @@ -79,7 +81,7 @@
configure_logger,
delete_directory,
is_port_available,
kill_child_process,
kill_process_tree,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
set_prometheus_multiproc_dir,
Expand Down Expand Up @@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)

# Register the signal handler.
# The child processes will send SIGQUIT to this process when any error happens
# This process then clean up the whole process tree
def sigquit_handler(signum, frame):
kill_process_tree(os.getpid())

signal.signal(signal.SIGQUIT, sigquit_handler)

# Set mp start method
mp.set_start_method("spawn", force=True)


Expand All @@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True)
kill_process_tree(os.getpid())
return

model_info = res.json()
Expand Down Expand Up @@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(include_self=True)
kill_process_tree(os.getpid())
return

# logger.info(f"{res.json()=}")
Expand Down Expand Up @@ -700,7 +711,7 @@ def __init__(

def shutdown(self):
if self.pid is not None:
kill_child_process(self.pid, include_self=True)
kill_process_tree(self.pid)
self.pid = None

def cache_prefix(self, prefix: str):
Expand Down Expand Up @@ -924,7 +935,7 @@ async def generator_wrapper():
return ret

def shutdown(self):
kill_child_process()
kill_process_tree(os.getpid(), include_parent=False)

def get_tokenizer(self):
global tokenizer_manager
Expand Down
28 changes: 8 additions & 20 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
)


def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
current_process = psutil.Process()
parent_process = current_process.parent()
kill_child_process(
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
current_process.kill()
except psutil.NoSuchProcess:
pass


def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
"""Kill the process and all its child processes."""
if parent_pid is None:
parent_pid = os.getpid()
include_parent = False

try:
itself = psutil.Process(pid)
itself = psutil.Process(parent_pid)
except psutil.NoSuchProcess:
return

Expand All @@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
except psutil.NoSuchProcess:
pass

if include_self:
if include_parent:
try:
itself.kill()

# Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
# so we send an additional signal to kill them.
itself.send_signal(signal.SIGINT)
itself.send_signal(signal.SIGQUIT)
except psutil.NoSuchProcess:
pass

Expand Down
Loading

0 comments on commit d4fc1a7

Please sign in to comment.