Skip to content

Commit

Permalink
Add a watch dog thread (#1816)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Oct 27, 2024
1 parent 1be853e commit 86fc0d7
Show file tree
Hide file tree
Showing 34 changed files with 96 additions and 53 deletions.
2 changes: 1 addition & 1 deletion python/sglang/bench_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,4 +550,4 @@ def main(server_args, bench_args):
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()
5 changes: 2 additions & 3 deletions python/sglang/bench_server_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import itertools
import json
import multiprocessing
import os
import time
from typing import Tuple

Expand Down Expand Up @@ -70,7 +69,7 @@ def launch_server_internal(server_args):
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()


def launch_server_process(server_args: ServerArgs):
Expand Down Expand Up @@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
)
finally:
if proc:
kill_child_process(proc.pid)
kill_child_process(proc.pid, include_self=True)

print(f"\nResults are saved to {bench_args.result_filename}")

Expand Down
2 changes: 1 addition & 1 deletion python/sglang/launch_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
except Exception as e:
raise e
finally:
kill_child_process(os.getpid(), including_parent=False)
kill_child_process()
38 changes: 33 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import logging
import os
import threading
import time
import warnings
from collections import deque
Expand Down Expand Up @@ -222,10 +223,11 @@ def __init__(
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.last_stats_tic = time.time()
self.stream_interval = server_args.stream_interval

# Init chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
Expand Down Expand Up @@ -272,6 +274,11 @@ def __init__(

self.batch_is_full = False

# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
t.start()

# Init profiler
if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
self.profiler = None
Expand All @@ -289,6 +296,23 @@ def __init__(
with_stack=True,
)

def watchdog_thread(self):
    """Kill the server if the active batch makes no forward progress.

    Intended to run in a background thread. Every ``watchdog_timeout / 2``
    seconds it samples ``self.forward_ct`` (incremented at the start of
    ``run_batch``). If a batch is active (``self.cur_batch is not None``)
    and the counter has not advanced for more than ``watchdog_timeout``
    seconds, it logs an error, leaves the polling loop and calls
    ``kill_parent_process()``.

    Reads: ``self.cur_batch``, ``self.forward_ct``, ``self.watchdog_timeout``.
    Writes: ``self.watchdog_last_forward_ct``, ``self.watchdog_last_time``.
    """
    self.watchdog_last_forward_ct = 0
    self.watchdog_last_time = time.time()

    while True:
        # Sample the clock once so the comparison and the bookkeeping
        # below agree on what "now" is for this poll.
        now = time.time()

        if self.cur_batch is None:
            # Idle scheduler: there is nothing that can hang. Keep the
            # timestamp fresh so that a poll landing between the moment a
            # new batch is assigned and the moment forward_ct is bumped
            # cannot mistake a long idle period for a stuck forward pass.
            self.watchdog_last_time = now
        elif self.watchdog_last_forward_ct != self.forward_ct:
            # Progress was made since the previous poll; restart the timer.
            self.watchdog_last_forward_ct = self.forward_ct
            self.watchdog_last_time = now
        elif now > self.watchdog_last_time + self.watchdog_timeout:
            # A batch is active and the counter has been frozen too long.
            logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
            break

        # Poll at twice the timeout frequency.
        time.sleep(self.watchdog_timeout / 2)

    # Only reached after a timeout was detected.
    kill_parent_process()

@torch.inference_mode()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
Expand All @@ -299,6 +323,7 @@ def event_loop_normal(self):
self.process_input_requests(recv_reqs)

batch = self.get_next_batch_to_run()
self.cur_batch = batch

if batch:
result = self.run_batch(batch)
Expand Down Expand Up @@ -746,6 +771,8 @@ def update_running_batch(self):

def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
self.forward_ct += 1

if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
Expand Down Expand Up @@ -778,6 +805,7 @@ def process_batch_result(self, batch: ScheduleBatch, result):
self.process_batch_result_prefill(batch, result)

def process_batch_result_prefill(self, batch: ScheduleBatch, result):

if self.is_generation:
logits_output, next_token_ids, bid = result

Expand Down Expand Up @@ -890,8 +918,8 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

self.token_to_kv_pool.free_group_end()

self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if self.tp_rank == 0 and self.forward_ct_decode % 40 == 0:
self.print_decode_stats()

def add_logprob_return_values(
Expand Down Expand Up @@ -984,7 +1012,7 @@ def stream_output(self, reqs: List[Req]):
else: # embedding or reward model
output_embeddings = []

is_stream_iter = self.decode_forward_ct % self.stream_interval == 0
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0

for req in reqs:
if req.finished() or (
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def launch_server(

# Send a warmup request
t = threading.Thread(
target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
)
t.start()

Expand Down Expand Up @@ -496,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)


def _wait_and_warmup(server_args, pipe_finish_writer, pid):
def _wait_and_warmup(server_args, pipe_finish_writer):
headers = {}
url = server_args.url()
if server_args.api_key:
Expand All @@ -519,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return

model_info = res.json()
Expand Down Expand Up @@ -551,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
if pipe_finish_writer is not None:
pipe_finish_writer.send(last_traceback)
logger.error(f"Initialization failed. warmup error: {last_traceback}")
kill_child_process(pid, including_parent=False)
kill_child_process(include_self=True)
return

# logger.info(f"{res.json()=}")
Expand Down Expand Up @@ -617,7 +617,7 @@ def __init__(

def shutdown(self):
if self.pid is not None:
kill_child_process(self.pid)
kill_child_process(self.pid, include_self=True)
self.pid = None

def cache_prefix(self, prefix: str):
Expand Down Expand Up @@ -834,7 +834,7 @@ async def generator_wrapper():
return ret

def shutdown(self):
kill_child_process(os.getpid(), including_parent=False)
kill_child_process(include_self=True)

def get_tokenizer(self):
global tokenizer_manager
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class ServerArgs:
api_key: Optional[str] = None
file_storage_pth: str = "SGLang_storage"
enable_cache_report: bool = False
watchdog_timeout: float = 600

# Data parallelism
dp_size: int = 1
Expand Down Expand Up @@ -429,6 +430,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
action="store_true",
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
)
parser.add_argument(
"--watchdog-timeout",
type=float,
default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
)

# Data parallelism
parser.add_argument(
Expand Down
21 changes: 15 additions & 6 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,17 +398,26 @@ def kill_parent_process():
"""Kill the parent process and all children of the parent process."""
current_process = psutil.Process()
parent_process = current_process.parent()
kill_child_process(parent_process.pid, skip_pid=current_process.pid)
kill_child_process(
parent_process.pid, include_self=True, skip_pid=current_process.pid
)
try:
current_process.kill()
except psutil.NoSuchProcess:
pass


def kill_child_process(pid, including_parent=True, skip_pid=None):
def kill_child_process(pid=None, include_self=False, skip_pid=None):
"""Kill the process and all its children process."""
if pid is None:
pid = os.getpid()

try:
parent = psutil.Process(pid)
itself = psutil.Process(pid)
except psutil.NoSuchProcess:
return

children = parent.children(recursive=True)
children = itself.children(recursive=True)
for child in children:
if child.pid == skip_pid:
continue
Expand All @@ -417,9 +426,9 @@ def kill_child_process(pid, including_parent=True, skip_pid=None):
except psutil.NoSuchProcess:
pass

if including_parent:
if include_self:
try:
parent.kill()
itself.kill()
except psutil.NoSuchProcess:
pass

Expand Down
10 changes: 5 additions & 5 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def run_one_file(filename):
)
assert ret_code == 0
except TimeoutError:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
time.sleep(5)
print(
f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
Expand Down Expand Up @@ -563,7 +563,7 @@ def run_bench_serving(
try:
res = run_benchmark(args)
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)

assert res["completed"] == num_prompts
return res
Expand Down Expand Up @@ -596,7 +596,7 @@ def run_bench_latency(model, other_args):
lastline = output.split("\n")[-3]
output_throughput = float(lastline.split(" ")[-2])
finally:
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)

return output_throughput

Expand Down Expand Up @@ -707,8 +707,8 @@ def run_mmlu_test(
pass

# Clean up everything
kill_child_process(process.pid)
kill_child_process(process.pid)
kill_child_process(process.pid, include_self=True)
kill_child_process(process.pid, include_self=True)
stdout.close()
stderr.close()
if os.path.exists(STDOUT_FILENAME):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def run_decode(
self,
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_cache_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_data_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_double_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_embedding_openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def run_embedding(self, use_list_input, token_input):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_eval_accuracy_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_eval_accuracy_large_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_eval_accuracy_mini.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def test_mmlu(self):
args = SimpleNamespace(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_json_constrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)

def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1):
response = requests.post(
Expand Down
2 changes: 1 addition & 1 deletion test/srt/test_large_max_new_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def setUpClass(cls):

@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
kill_child_process(cls.process.pid, include_self=True)
cls.stdout.close()
cls.stderr.close()
os.remove("stdout.txt")
Expand Down
Loading

0 comments on commit 86fc0d7

Please sign in to comment.