From cd51758fade4119b3f6233444c3bfac91ed5eba9 Mon Sep 17 00:00:00 2001
From: HAI
Date: Wed, 27 Nov 2024 21:18:51 -0800
Subject: [PATCH 1/8] Rename tuned MI300X config files for fused_moe_triton
 (#2228)

---
 ...=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} | 0
 ...=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename python/sglang/srt/layers/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json => E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} (100%)
 rename python/sglang/srt/layers/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json => E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} (100%)

diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
similarity index 100%
rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=float8.json
rename to python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json b/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json
similarity index 100%
rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=float8.json
rename to python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json

From b79fffdcb5c52ba8fdc72a9f18aabc3cd50bc7ff Mon Sep 17 00:00:00 2001
From: HAI
Date: Wed, 27 Nov 2024 22:46:55 -0800
Subject: [PATCH 2/8] Update Install Method 2. From source (#2232)

---
 docker/Dockerfile.rocm |  2 +-
 docs/start/install.md  | 11 +++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index 3a2c2761b0..2b9296d8f2 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -1,5 +1,5 @@
 # Usage (to build SGLang ROCm docker image):
-# docker build --build-arg SGL_BRANCH=v0.3.6.post2 -t testImage -f Dockerfile.rocm .
+# docker build --build-arg SGL_BRANCH=v0.3.6.post2 -t v0.3.6.post2-rocm620 -f Dockerfile.rocm .

 # default base image
 ARG BASE_IMAGE="rocm/vllm-dev:20241022"
diff --git a/docs/start/install.md b/docs/start/install.md
index 220fc3c5b5..8debab0eb1 100644
--- a/docs/start/install.md
+++ b/docs/start/install.md
@@ -28,6 +28,17 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 
 Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 
+Note: For AMD ROCm systems with Instinct/MI GPUs, do the following instead:
+
+```
+# Use the latest release branch
+git clone -b v0.3.6.post2 https://github.com/sgl-project/sglang.git
+cd sglang
+
+pip install --upgrade pip
+pip install -e "python[all_hip]"
+```
+
 ## Method 3: Using docker
 
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
Replace `` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens). From 09798b36cd31f8f9787cc43a5aed9bca173ada40 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 27 Nov 2024 23:37:20 -0800 Subject: [PATCH 3/8] Fix chunked prefill size for bench_offline_throughput (#2234) --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b545e00c07..144ade58ea 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -144,7 +144,7 @@ def __post_init__(self): if self.served_model_name is None: self.served_model_name = self.model_path - if self.chunked_prefill_size <= 0: + if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0: # Disable chunked prefill self.chunked_prefill_size = None From fb915bd1a2e0f1425ecfd3ab47cace317abf1ddb Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 27 Nov 2024 23:44:33 -0800 Subject: [PATCH 4/8] Disable overlap scheduler for multimodal models (#2235) --- python/sglang/srt/managers/scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 663d4c4f93..2563bb5596 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -170,6 +170,10 @@ def __init__( self.enable_overlap = False logger.info("Overlap scheduler is disabled for embedding models.") + if self.model_config.is_multimodal: + self.enable_overlap = False + logger.info("Overlap scheduler is disabled for multimodal models.") + if self.enable_overlap: self.disable_jump_forward = True From db674e3d24dd224df42aef37cad55be130062a6f Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Thu, 28 Nov 2024 10:15:20 +0200 Subject: [PATCH 5/8] Add OLMo2 model. (#2233) --- python/sglang/srt/models/olmo2.py | 392 ++++++++++++++++++++++ test/srt/models/test_generation_models.py | 1 + 2 files changed, 393 insertions(+) create mode 100755 python/sglang/srt/models/olmo2.py diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py new file mode 100755 index 0000000000..d73a6d5a3d --- /dev/null +++ b/python/sglang/srt/models/olmo2.py @@ -0,0 +1,392 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/olmo2.py +"""Inference-only OLMo2 model compatible with HuggingFace weights.""" +from functools import partial +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.utils import make_layers + + +class Olmo2Attention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = self.config.num_key_value_heads + + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + ) + self.tp_rank = get_tensor_model_parallel_rank() + + self.k_norm = RMSNorm( + self.total_num_kv_heads * self.head_dim, + eps=self.config.rms_norm_eps, + ) + self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + # Attention output projection. 
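+        # Each tensor-parallel rank projects its local (num_heads * head_dim)
+        # slice back to hidden_size; RowParallelLinear all-reduces the partial
+        # outputs across ranks.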
+ self.o_proj = RowParallelLinear( + self.head_dim * self.total_num_heads, + self.hidden_size, + bias=config.attention_bias, + ) + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class Olmo2MLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(x)`` in ``LN(MLP(x + LN(Attention(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Olmo2DecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: PretrainedConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + # Attention block. + self.self_attn = Olmo2Attention(config, layer_id, quant_config) + + # MLP block. + self.mlp = Olmo2MLP(config, quant_config) + + # RMSNorm + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.post_feedforward_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + # Attention block. + residual = hidden_states + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # MLP block. 
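+        # As in the attention block, the norm is applied to the sublayer output
+        # before the residual add (OLMo2's post-norm arrangement).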
+ residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Olmo2Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.layers = make_layers( + config.num_hidden_layers, + lambda idx, prefix: Olmo2DecoderLayer( + layer_id=idx, + config=config, + quant_config=quant_config, + ), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + # Apply blocks one-by-one. + for layer_id, decoder_layer in enumerate(self.layers): + # shape: (batch_size, seq_len, d_model) + hidden_states = decoder_layer( + positions, + hidden_states, + forward_batch, + ) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Olmo2ForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__( + self, + config: PretrainedConfig, + cache_config=None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.model = Olmo2Model(config, quant_config) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head.weight, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. 
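+        # (When embeddings are tied, lm_head is just an alias of
+        # model.embed_tokens; see __init__ above.)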
+ if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = Olmo2ForCausalLM diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index dbe35b0e7e..d9f1795341 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -56,6 +56,7 @@ class ModelCase: ModelCase("THUDM/glm-4-9b-chat"), ModelCase("openai-community/gpt2"), ModelCase("microsoft/Phi-3-small-8k-instruct"), + ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), ] TORCH_DTYPES = [torch.float16] From d4fc1a70e3187c914043a1ffc619adbb0c3c6860 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 28 Nov 2024 00:22:39 -0800 Subject: [PATCH 6/8] Crash the server correctly during error (#2231) --- python/sglang/bench_one_batch.py | 9 ++---- python/sglang/bench_one_batch_server.py | 7 +++-- python/sglang/launch_server.py | 4 +-- .../srt/managers/data_parallel_controller.py | 18 +++++------- .../srt/managers/detokenizer_manager.py | 11 +++++--- python/sglang/srt/managers/scheduler.py | 13 +++++---- .../sglang/srt/managers/tokenizer_manager.py | 4 +-- .../srt/managers/tp_worker_overlap_thread.py | 13 +++++++-- python/sglang/srt/server.py | 21 ++++++++++---- python/sglang/srt/utils.py | 28 ++++++------------- python/sglang/test/test_utils.py | 12 ++++---- python/sglang/utils.py | 4 +-- rust/py_test/test_launch_server.py | 4 +-- .../test_srt_endpoint_with_penalizers.py | 4 +-- test/srt/test_cache_report.py | 4 +-- test/srt/test_data_parallelism.py | 4 +-- test/srt/test_double_sparsity.py | 4 +-- test/srt/test_dp_attention.py | 4 +-- test/srt/test_embedding_openai_server.py | 4 +-- test/srt/test_eval_accuracy_large.py | 4 +-- ...est_eval_accuracy_large_chunked_prefill.py | 4 +-- ...al_accuracy_large_mixed_chunked_prefill.py | 4 +-- test/srt/test_eval_accuracy_mini.py | 4 +-- test/srt/test_input_embeddings.py | 4 +-- test/srt/test_json_constrained.py | 4 +-- test/srt/test_large_max_new_tokens.py | 4 +-- test/srt/test_matched_stop.py | 4 +-- test/srt/test_metrics.py | 4 +-- test/srt/test_mla.py | 4 +-- test/srt/test_mla_fp8.py | 4 +-- test/srt/test_moe_eval_accuracy_large.py | 4 +-- test/srt/test_nightly_gsm8k_eval.py | 4 +-- test/srt/test_nightly_human_eval.py | 6 ++-- test/srt/test_openai_server.py | 4 +-- test/srt/test_pytorch_sampling_backend.py | 4 +-- test/srt/test_radix_attention.py | 4 +-- test/srt/test_retract_decode.py | 4 +-- test/srt/test_session_control.py | 6 ++-- test/srt/test_skip_tokenizer_init.py | 4 +-- test/srt/test_srt_endpoint.py | 4 +-- test/srt/test_torch_compile.py | 4 +-- test/srt/test_torch_compile_moe.py | 4 +-- test/srt/test_torchao.py | 4 +-- test/srt/test_triton_attention_backend.py | 4 +-- test/srt/test_update_weights.py | 4 +-- test/srt/test_vision_openai_server.py | 6 ++-- 46 files changed, 147 insertions(+), 139 deletions(-) diff 
--git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 863bc58399..9bbe9b0f18 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -47,6 +47,7 @@ import json import logging import multiprocessing +import os import time from typing import Tuple @@ -62,11 +63,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - configure_logger, - kill_child_process, - suppress_other_loggers, -) +from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers @dataclasses.dataclass @@ -468,4 +465,4 @@ def main(server_args, bench_args): main(server_args, bench_args) finally: if server_args.tp_size != 1: - kill_child_process() + kill_process_tree(os.getpid(), include_parent=False) diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 9d6048bc11..01cc561e1c 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -15,6 +15,7 @@ import itertools import json import multiprocessing +import os import time from typing import Tuple @@ -23,7 +24,7 @@ from sglang.srt.server import launch_server from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree @dataclasses.dataclass @@ -69,7 +70,7 @@ def launch_server_internal(server_args): except Exception as e: raise e finally: - kill_child_process() + kill_process_tree(os.getpid(), include_parent=False) def launch_server_process(server_args: ServerArgs): @@ -175,7 +176,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs): ) finally: if proc: - kill_child_process(proc.pid, include_self=True) + kill_process_tree(proc.pid) print(f"\nResults are saved to {bench_args.result_filename}") diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 3e2cd4a97f..b2ad1b3209 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -4,7 +4,7 @@ from sglang.srt.server import launch_server from sglang.srt.server_args import prepare_server_args -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree if __name__ == "__main__": server_args = prepare_server_args(sys.argv[1:]) @@ -12,4 +12,4 @@ try: launch_server(server_args) finally: - kill_child_process() + kill_process_tree(os.getpid(), include_parent=False) diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index d4730e3f7a..8edb79417e 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -15,9 +15,11 @@ import logging import multiprocessing as mp +import signal import threading from enum import Enum, auto +import psutil import zmq from sglang.srt.managers.io_struct import ( @@ -26,13 +28,7 @@ ) from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - bind_port, - configure_logger, - get_zmq_socket, - kill_parent_process, - suppress_other_loggers, -) +from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -235,7 +231,7 @@ def run_data_parallel_controller_process( 
pipe_writer, ): configure_logger(server_args) - suppress_other_loggers() + parent_process = psutil.Process().parent() try: controller = DataParallelController(server_args, port_args) @@ -244,6 +240,6 @@ def run_data_parallel_controller_process( ) controller.event_loop() except Exception: - msg = get_exception_traceback() - logger.error(msg) - kill_parent_process() + traceback = get_exception_traceback() + logger.error(f"DataParallelController hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 18f77424dc..e74ba5026c 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -15,9 +15,11 @@ import dataclasses import logging +import signal from collections import OrderedDict from typing import List, Union +import psutil import zmq from sglang.srt.hf_transformers_utils import get_tokenizer @@ -28,7 +30,7 @@ ) from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import configure_logger, get_zmq_socket, kill_parent_process +from sglang.srt.utils import configure_logger, get_zmq_socket from sglang.utils import find_printable_text, get_exception_traceback logger = logging.getLogger(__name__) @@ -193,11 +195,12 @@ def run_detokenizer_process( port_args: PortArgs, ): configure_logger(server_args) + parent_process = psutil.Process().parent() try: manager = DetokenizerManager(server_args, port_args) manager.event_loop() except Exception: - msg = get_exception_traceback() - logger.error(msg) - kill_parent_process() + traceback = get_exception_traceback() + logger.error(f"DetokenizerManager hit an exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2563bb5596..a327f37a2f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -15,6 +15,7 @@ import logging import os +import signal import threading import time import warnings @@ -23,6 +24,7 @@ from types import SimpleNamespace from typing import List, Optional +import psutil import torch import zmq @@ -73,7 +75,6 @@ crash_on_warnings, get_bool_env_var, get_zmq_socket, - kill_parent_process, set_gpu_proc_affinity, set_random_seed, suppress_other_loggers, @@ -316,6 +317,7 @@ def __init__( self.watchdog_timeout = server_args.watchdog_timeout t = threading.Thread(target=self.watchdog_thread, daemon=True) t.start() + self.parent_process = psutil.Process().parent() # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": @@ -359,7 +361,7 @@ def watchdog_thread(self): self.watchdog_last_time = time.time() time.sleep(self.watchdog_timeout / 2) - kill_parent_process() + self.parent_process.send_signal(signal.SIGQUIT) @torch.no_grad() def event_loop_normal(self): @@ -1423,6 +1425,7 @@ def run_scheduler_process( configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") suppress_other_loggers() + parent_process = psutil.Process().parent() try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) @@ -1434,6 +1437,6 @@ def run_scheduler_process( else: scheduler.event_loop_normal() except Exception: - msg = get_exception_traceback() - logger.error(msg) - kill_parent_process() + traceback = get_exception_traceback() + logger.error(f"Scheduler hit an 
exception: {traceback}") + parent_process.send_signal(signal.SIGQUIT) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 001ecc1ebe..15518e9e5f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -58,7 +58,7 @@ from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_zmq_socket, kill_child_process +from sglang.srt.utils import get_zmq_socket, kill_process_tree asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -532,7 +532,7 @@ async def sigterm_watchdog(self): else: break - kill_child_process(include_self=True) + kill_process_tree(os.getpid(), include_parent=True) sys.exit(0) async def handle_loop(self): diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 3b53759a75..a5412094c9 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -15,16 +15,19 @@ import dataclasses import logging +import signal import threading from queue import Queue from typing import Optional +import psutil import torch from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.server_args import ServerArgs +from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -70,6 +73,7 @@ def __init__( target=self.forward_thread_func, ) self.forward_thread.start() + self.parent_process = psutil.Process().parent() def get_worker_info(self): return self.worker.get_worker_info() @@ -87,8 +91,13 @@ def get_memory_pool(self): ) def forward_thread_func(self): - with torch.cuda.stream(self.forward_stream): - self.forward_thread_func_() + try: + with torch.cuda.stream(self.forward_stream): + self.forward_thread_func_() + except Exception: + traceback = get_exception_traceback() + logger.error(f"TpModelWorkerClient hit an exception: {traceback}") + self.parent_process.send_signal(signal.SIGQUIT) @torch.no_grad() def forward_thread_func_(self): diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index a4753a1345..c958930671 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -23,6 +23,8 @@ import logging import multiprocessing as mp import os +import signal +import sys import threading import time from http import HTTPStatus @@ -79,7 +81,7 @@ configure_logger, delete_directory, is_port_available, - kill_child_process, + kill_process_tree, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, set_prometheus_multiproc_dir, @@ -572,6 +574,15 @@ def _set_envs_and_config(server_args: ServerArgs): "at https://docs.flashinfer.ai/installation.html.", ) + # Register the signal handler. 
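+    # (Signal handlers can only be registered from the main thread.)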
+    # The child processes will send SIGQUIT to this process when any error happens.
+    # This process then cleans up the whole process tree.
+    def sigquit_handler(signum, frame):
+        kill_process_tree(os.getpid())
+
+    signal.signal(signal.SIGQUIT, sigquit_handler)
+
+    # Set mp start method
     mp.set_start_method("spawn", force=True)
@@ -598,7 +609,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return
 
     model_info = res.json()
@@ -631,7 +642,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return
 
     # logger.info(f"{res.json()=}")
@@ -700,7 +711,7 @@ def __init__(
 
     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid, include_self=True)
+            kill_process_tree(self.pid)
             self.pid = None
 
     def cache_prefix(self, prefix: str):
@@ -924,7 +935,7 @@ async def generator_wrapper():
         return ret
 
     def shutdown(self):
-        kill_child_process()
+        kill_process_tree(os.getpid(), include_parent=False)
 
     def get_tokenizer(self):
         global tokenizer_manager
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 4a974e2e75..46b4db8e88 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
         )
 
 
-def kill_parent_process():
-    """Kill the parent process and all children of the parent process."""
-    current_process = psutil.Process()
-    parent_process = current_process.parent()
-    kill_child_process(
-        parent_process.pid, include_self=True, skip_pid=current_process.pid
-    )
-    try:
-        current_process.kill()
-    except psutil.NoSuchProcess:
-        pass
-
-
-def kill_child_process(pid=None, include_self=False, skip_pid=None):
-    """Kill the process and all its children process."""
-    if pid is None:
-        pid = os.getpid()
+def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
+    """Kill the process and all its child processes."""
+    if parent_pid is None:
+        parent_pid = os.getpid()
+        include_parent = False
 
     try:
-        itself = psutil.Process(pid)
+        itself = psutil.Process(parent_pid)
     except psutil.NoSuchProcess:
         return
 
@@ -475,13 +463,13 @@
         except psutil.NoSuchProcess:
             pass
 
-    if include_self:
+    if include_parent:
         try:
             itself.kill()
 
             # Sometimes processes cannot be killed with SIGKILL (e.g., PID=1 launched by kubernetes),
             # so we send an additional signal to kill them.
- itself.send_signal(signal.SIGINT) + itself.send_signal(signal.SIGQUIT) except psutil.NoSuchProcess: pass diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 3089668443..3f6cce23d5 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -22,7 +22,7 @@ from sglang.global_config import global_config from sglang.lang.backend.openai import OpenAI from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.utils import get_bool_env_var, kill_child_process +from sglang.srt.utils import get_bool_env_var, kill_process_tree from sglang.test.run_eval import run_eval from sglang.utils import get_exception_traceback @@ -504,7 +504,7 @@ def run_one_file(filename): ) assert ret_code == 0 except TimeoutError: - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) time.sleep(5) print( f"\nTimeout after {timeout_per_file} seconds when running {filename}\n", @@ -578,7 +578,7 @@ def run_bench_serving( run_benchmark(warmup_args) res = run_benchmark(args) finally: - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) assert res["completed"] == num_prompts return res @@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args): lastline = output.split("\n")[-3] output_throughput = float(lastline.split(" ")[-2]) finally: - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) return output_throughput @@ -710,8 +710,8 @@ def run_and_check_memory_leak( workload_func(base_url, model) # Clean up everything - kill_child_process(process.pid, include_self=True) - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) + kill_process_tree(process.pid) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): diff --git a/python/sglang/utils.py b/python/sglang/utils.py index e694dc198d..c1bf62ef98 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None: def terminate_process(process): - from sglang.srt.utils import kill_child_process + from sglang.srt.utils import kill_process_tree - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) def print_highlight(html_content: str): diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py index 7fdaea6b1c..f39b341df2 100644 --- a/rust/py_test/test_launch_server.py +++ b/rust/py_test/test_launch_server.py @@ -5,7 +5,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -79,7 +79,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 689d52a1c5..0eccb3407f 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -4,7 +4,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -31,7 +31,7 @@ 
def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_decode( self, diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py index 5d498ac3f4..f128aa147d 100644 --- a/test/srt/test_cache_report.py +++ b/test/srt/test_cache_report.py @@ -4,7 +4,7 @@ import openai import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, @@ -44,7 +44,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): response = requests.post( diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py index f34313ea09..22d0006640 100644 --- a/test/srt/test_data_parallelism.py +++ b/test/srt/test_data_parallelism.py @@ -4,7 +4,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -28,7 +28,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_double_sparsity.py b/test/srt/test_double_sparsity.py index 20896aff28..060a7926f6 100644 --- a/test/srt/test_double_sparsity.py +++ b/test/srt/test_double_sparsity.py @@ -2,7 +2,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -45,7 +45,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py index 32fe75a59b..31c9cc71ba 100644 --- a/test/srt/test_dp_attention.py +++ b/test/srt/test_dp_attention.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -30,7 +30,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py index 666297c650..8097bf42cd 100644 --- a/test/srt/test_embedding_openai_server.py +++ b/test/srt/test_embedding_openai_server.py @@ -3,7 +3,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -28,7 +28,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def 
run_embedding(self, use_list_input, token_input): client = openai.Client(api_key=self.api_key, base_url=self.base_url) diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py index 318390d100..f7fb3cec3e 100644 --- a/test/srt/test_eval_accuracy_large.py +++ b/test/srt/test_eval_accuracy_large.py @@ -6,7 +6,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -30,7 +30,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_large_chunked_prefill.py b/test/srt/test_eval_accuracy_large_chunked_prefill.py index 2e9ff59cda..c8ce5cff2b 100644 --- a/test/srt/test_eval_accuracy_large_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_chunked_prefill.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -25,7 +25,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py index 0fb08e64f4..3bc115874f 100644 --- a/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py +++ b/test/srt/test_eval_accuracy_large_mixed_chunked_prefill.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -31,7 +31,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index a718feff76..a008c3869e 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -22,7 +22,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_input_embeddings.py b/test/srt/test_input_embeddings.py index b57b61dad4..04d54c6bbd 100644 --- a/test/srt/test_input_embeddings.py +++ b/test/srt/test_input_embeddings.py @@ -4,7 +4,7 @@ import requests from transformers import AutoModelForCausalLM, AutoTokenizer -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ 
-107,7 +107,7 @@ def test_compare_text_vs_embedding(self): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) if __name__ == "__main__": diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index ae27b036f9..28acdabd9d 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -9,7 +9,7 @@ import openai import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, @@ -46,7 +46,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1): response = requests.post( diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py index 5ed2b06fc1..dcaeef5aa1 100644 --- a/test/srt/test_large_max_new_tokens.py +++ b/test/srt/test_large_max_new_tokens.py @@ -10,7 +10,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -52,7 +52,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) cls.stdout.close() cls.stderr.close() os.remove(STDOUT_FILENAME) diff --git a/test/srt/test_matched_stop.py b/test/srt/test_matched_stop.py index 81d08b0913..7b09a6d35f 100644 --- a/test/srt/test_matched_stop.py +++ b/test/srt/test_matched_stop.py @@ -3,7 +3,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, @@ -32,7 +32,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_completions_generation( self, diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 163a7cc0e0..3b73e500d7 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -2,7 +2,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -75,7 +75,7 @@ def test_metrics_enabled(self): self.assertIn("_bucket{", metrics_content) finally: - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) if __name__ == "__main__": diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index a11be3950a..b8105a84af 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -25,7 +25,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git 
a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py index 5091759a9f..769bdf34da 100644 --- a/test/srt/test_mla_fp8.py +++ b/test/srt/test_mla_fp8.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST, @@ -31,7 +31,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mgsm_en(self): args = SimpleNamespace( diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 9880a81626..6f3affbba4 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -6,7 +6,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -35,7 +35,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 7c208e84b9..8466c2c648 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -6,7 +6,7 @@ from datetime import datetime from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1, @@ -132,7 +132,7 @@ def setUp(self): def tearDown(self): if self.process: - kill_child_process(self.process.pid, include_self=True) + kill_process_tree(self.process.pid) def test_mgsm_en_all_models(self): warnings.filterwarnings( diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py index f69bbe1321..626e6fb153 100644 --- a/test/srt/test_nightly_human_eval.py +++ b/test/srt/test_nightly_human_eval.py @@ -6,7 +6,7 @@ from test_nightly_gsm8k_eval import launch_server, parse_models -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2, @@ -32,9 +32,9 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): if cls.process: - kill_child_process(cls.process.pid) + kill_process_tree(cls.process.pid) if cls.eval_process: - kill_child_process(cls.eval_process.pid) + kill_process_tree(cls.eval_process.pid) def run_evalplus(self, model): print("Delete evalplus results") diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 1e18e23ef3..d007bed31e 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -11,7 +11,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -37,7 +37,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, 
include_self=True) + kill_process_tree(cls.process.pid) def run_completion( self, echo, logprobs, use_list_input, parallel_sample_num, token_input diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py index 9aa6c33009..4f1403e0a8 100644 --- a/test/srt/test_pytorch_sampling_backend.py +++ b/test/srt/test_pytorch_sampling_backend.py @@ -3,7 +3,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -27,7 +27,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py index cdba7573d4..207303c8c0 100644 --- a/test/srt/test_radix_attention.py +++ b/test/srt/test_radix_attention.py @@ -8,7 +8,7 @@ DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, - kill_child_process, + kill_process_tree, popen_launch_server, ) @@ -80,7 +80,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_radix_attention(self): nodes = gen_radix_tree() diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py index 834c51f9d5..5f169cdb68 100644 --- a/test/srt/test_retract_decode.py +++ b/test/srt/test_retract_decode.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -22,7 +22,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 7396779f64..8558b4249e 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -9,7 +9,7 @@ import requests from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -29,7 +29,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_session_control(self): chunks = [ @@ -191,7 +191,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_session_control(self): text_chunks = [ diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index 7ec73b15d6..bc99b23ad5 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -7,7 +7,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -30,7 +30,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - 
kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): max_new_tokens = 32 diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index fb50943f16..006059e035 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -9,7 +9,7 @@ import numpy as np import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -29,7 +29,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def run_decode( self, diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 76945f963d..6f3b344b3c 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -4,7 +4,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -28,7 +28,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py index e744e66867..89d4ed6bdf 100644 --- a/test/srt/test_torch_compile_moe.py +++ b/test/srt/test_torch_compile_moe.py @@ -4,7 +4,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST, @@ -28,7 +28,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py index 2a2fcb8dfc..a6414c60b8 100644 --- a/test/srt/test_torchao.py +++ b/test/srt/test_torchao.py @@ -3,7 +3,7 @@ import requests -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -27,7 +27,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - kill_child_process(cls.process.pid, include_self=True) + kill_process_tree(cls.process.pid) def test_mmlu(self): args = SimpleNamespace( diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py index a4d19bec0f..905590965d 100644 --- a/test/srt/test_triton_attention_backend.py +++ b/test/srt/test_triton_attention_backend.py @@ -6,7 +6,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_child_process +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -54,7 +54,7 @@ def test_mmlu(self): metrics = run_eval(args) self.assertGreaterEqual(metrics["score"], 0.65) finally: - kill_child_process(process.pid, include_self=True) + kill_process_tree(process.pid) if __name__ == "__main__": diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index 
--- a/test/srt/test_update_weights.py
+++ b/test/srt/test_update_weights.py
@@ -3,7 +3,7 @@
 
 import requests
 
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -23,7 +23,7 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid, include_self=True)
+        kill_process_tree(cls.process.pid)
 
     def run_decode(self):
         response = requests.post(
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 95a1624cf3..e19e6b01d5 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -17,7 +17,7 @@
 from decord import VideoReader, cpu
 from PIL import Image
 
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
@@ -46,7 +46,7 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid, include_self=True)
+        kill_process_tree(cls.process.pid)
 
     def test_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -387,7 +387,7 @@ def setUpClass(cls):
 
     @classmethod
     def tearDownClass(cls):
-        kill_child_process(cls.process.pid, include_self=True)
+        kill_process_tree(cls.process.pid)
 
     def test_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
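[Editor's note] Patch 6 above mechanically replaces every `kill_child_process(pid, include_self=True)` call with `kill_process_tree(pid)`, folding the old `include_self` flag into the new helper's default behavior. For readers unfamiliar with the utility, below is a minimal sketch of what a process-tree killer typically looks like, assuming `psutil` is available; it illustrates the idea only and is not the actual `sglang.srt.utils` implementation.

```python
import signal

import psutil


def kill_process_tree(parent_pid: int, include_parent: bool = True) -> None:
    """Kill a process and all of its descendants (sketch, POSIX-only)."""
    try:
        parent = psutil.Process(parent_pid)
    except psutil.NoSuchProcess:
        return

    # Snapshot the children first so the list stays stable while signaling.
    for child in parent.children(recursive=True):
        try:
            child.send_signal(signal.SIGKILL)
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            parent.send_signal(signal.SIGKILL)
        except psutil.NoSuchProcess:
            pass
```

With `include_parent` defaulting to true, the test call sites shrink from `kill_child_process(cls.process.pid, include_self=True)` to `kill_process_tree(cls.process.pid)`, which is exactly the rewrite the diff performs.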
From b2ccf36d4d93d47b59399a93e7e00444b812a28c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 28 Nov 2024 02:22:15 -0800
Subject: [PATCH 7/8] Fix memory leak during abort (#2238)

---
 .github/workflows/pr-test.yml                |  8 +--
 python/sglang/srt/managers/schedule_batch.py |  5 ++
 python/sglang/srt/managers/scheduler.py      |  8 ++-
 python/sglang/test/test_utils.py             | 15 +++++-
 test/srt/run_suite.py                        |  1 +
 test/srt/test_abort.py                       | 54 ++++++++++++++++++++
 6 files changed, 84 insertions(+), 7 deletions(-)
 create mode 100644 test/srt/test_abort.py

diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index a3d30324c3..0d7889a5f3 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
 
   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 15
 
   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -84,7 +84,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 14 --range-end 23
+          python3 run_suite.py --suite minimal --range-begin 15 --range-end 24
 
   unit-test-backend-part-4:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -101,7 +101,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 23
+          python3 run_suite.py --suite minimal --range-begin 24
 
   unit-test-backend-2-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 971809124b..f7d55ed9b1 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -231,6 +231,7 @@ def __init__(
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -368,6 +369,10 @@ def check_finished(self):
         if self.finished():
             return
 
+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a327f37a2f..c7e8318117 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -579,6 +579,8 @@ def handle_generate_request(
                 "Image request length is longer than the KV cache pool size or "
                 "the max context length aborting because you cannot truncate the image embeds"
             )
+            req.image_inputs = None
+            req.origin_input_ids = [0]
             req.sampling_params.max_new_tokens = 0
             self.waiting_queue.append(req)
             return
@@ -1350,13 +1352,15 @@
 
             if to_del is not None:
                 del self.waiting_queue[to_del]
+                logger.debug(f"Abort queued request. {req.rid=}")
+                return
 
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
-                    req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 3f6cce23d5..a1646fb5fd 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
     enable_mixed_chunk,
     disable_overlap,
     chunked_prefill_size,
+    assert_has_abort,
 ):
-    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args = [
+        "--chunked-prefill-size",
+        str(chunked_prefill_size),
+        "--log-level",
+        "debug",
+    ]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
@@ -723,14 +729,19 @@
     # Assert success
     has_new_server = False
     has_leak = False
+    has_abort = False
     for line in output_lines:
         if "The server is fired" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
+        if "Abort" in line:
+            has_abort = True
 
     assert has_new_server
     assert not has_leak
+    if assert_has_abort:
+        assert has_abort
 
 
 def run_mmlu_test(
@@ -761,6 +772,7 @@ def workload_func(base_url, model):
         enable_mixed_chunk,
         disable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
 
 
@@ -800,4 +812,5 @@ def run_one(_):
         enable_mixed_chunk,
         enable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3f55eb25fd..c04a1671ed 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -10,6 +10,7 @@
         "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
+        "test_abort.py",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
diff --git a/test/srt/test_abort.py b/test/srt/test_abort.py
new file mode 100644
index 0000000000..ae27d83a85
--- /dev/null
+++ b/test/srt/test_abort.py
@@ -0,0 +1,54 @@
+import multiprocessing
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+
+from sglang.test.test_utils import run_and_check_memory_leak
+
+
+class TestAbort(unittest.TestCase):
+    def workload_func(self, base_url, model):
+        def process_func():
+            def run_one(_):
+                prompt = """
+                System: You are a helpful assistant.
+                User: What is the capital of France?
+                Assistant: The capital of France is
+                """
+
+                response = requests.post(
+                    f"{base_url}/generate",
+                    json={
+                        "text": prompt,
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 2048,
+                        },
+                    },
+                )
+                ret = response.json()
+
+            with ThreadPoolExecutor(16) as executor:
+                list(executor.map(run_one, list(range(16))))
+
+        p = multiprocessing.Process(target=process_func)
+        p.start()
+        time.sleep(0.5)
+        p.terminate()
+        time.sleep(10)
+
+    def test_memory_leak(self):
+        run_and_check_memory_leak(
+            self.workload_func,
+            disable_radix_cache=False,
+            enable_mixed_chunk=False,
+            disable_overlap=False,
+            chunked_prefill_size=8192,
+            assert_has_abort=True,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
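[Editor's note] The memory-leak fix in patch 7 changes abort handling from eager to deferred: `abort_request` used to set `FINISH_ABORT` and release the request's radix-cache entries immediately, which could race with a batch still in flight and leak or corrupt KV-cache state. The new code only sets `req.to_abort = True` and lets `check_finished` convert the flag into a finish reason at a safe point in the scheduling loop, so the normal finish path performs the cleanup. The sketch below condenses that pattern; the names mirror the diff (`to_abort`, `check_finished`, `FINISH_ABORT`, `abort_request`), but the bodies are simplified stand-ins, not the real scheduler code.

```python
class Req:
    """Simplified stand-in for sglang's Req; only the abort path is shown."""

    def __init__(self, rid: str):
        self.rid = rid
        self.to_abort = False
        self.finished_reason = None

    def finished(self) -> bool:
        return self.finished_reason is not None

    def check_finished(self) -> None:
        # Runs once per scheduler step, after the forward pass, so the
        # regular finish path (and its cache cleanup) handles the abort.
        if self.finished():
            return
        if self.to_abort:
            self.finished_reason = "FINISH_ABORT"


def abort_request(running_reqs: list, rid: str) -> None:
    # Only flag the request; freeing its cache entries here could race
    # with a batch that is still using them.
    for req in running_reqs:
        if req.rid == rid and not req.finished():
            req.to_abort = True
            break
```

The new `test_abort.py` exercises exactly this path: a client process fires 16 concurrent generate requests and is terminated half a second later, and `run_and_check_memory_leak` then scans the server's debug logs for both the `Abort` messages and any `leak` report.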
From 65fdb289294f890c1814277ffc6160fa93b07750 Mon Sep 17 00:00:00 2001
From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com>
Date: Thu, 28 Nov 2024 13:24:47 +0000
Subject: [PATCH 8/8] fix missing launch server import (#2242)

---
 python/sglang/launch_server.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py
index b2ad1b3209..6b0c25711c 100644
--- a/python/sglang/launch_server.py
+++ b/python/sglang/launch_server.py
@@ -1,5 +1,6 @@
 """Launch the inference server."""
 
+import os
 import sys
 
 from sglang.srt.server import launch_server
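[Editor's note] The one-line patch above makes sense with the rest of the file in view: after the `kill_process_tree` migration, `launch_server.py` presumably references `os` (for example via `os.getpid()`) in its shutdown path without importing the module, which would raise `NameError` on exit. The sketch below shows the likely shape of the file; it is reconstructed rather than copied from the repository, and the helper name `prepare_server_args` is an assumption.

```python
"""Launch the inference server."""

import os
import sys

from sglang.srt.server import launch_server
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree

if __name__ == "__main__":
    server_args = prepare_server_args(sys.argv[1:])

    try:
        launch_server(server_args)
    finally:
        # Without `import os`, this cleanup line raises NameError at shutdown.
        kill_process_tree(os.getpid(), include_parent=False)
```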