diff --git a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
index c3b8da7d6a..266d4c2ba0 100644
--- a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
+++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py
@@ -223,7 +223,6 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
         model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
         tokenizer_path=tokenizer_path,
         port=cur_port,
-        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
         json_model_override_args=json.dumps(model_override_args),
         tp_size=1,
     )
diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 47aca50593..f6511e3408 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -66,9 +66,8 @@
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    allocate_init_ports,
     configure_logger,
     kill_child_process,
     suppress_other_loggers,
@@ -127,11 +126,7 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
+    port_args = PortArgs.init_new(server_args)
     model_config = ModelConfig(
         server_args.model_path,
         server_args.trust_remote_code,
@@ -143,7 +138,7 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=server_args.additional_ports[-1],
+        nccl_port=port_args.nccl_ports[0],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 2c3cd1dbef..49c9e6fdb5 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -59,10 +59,10 @@ def __init__(
         # Init inter-process communication
         context = zmq.Context(2)
         self.recv_from_scheduler = context.socket(zmq.PULL)
-        self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
+        self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

         self.send_to_tokenizer = context.socket(zmq.PUSH)
-        self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
+        self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")

         if server_args.skip_tokenizer_init:
             self.tokenizer = None
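The scheduler and tokenizer-manager diffs below make the same substitution: each `tcp://127.0.0.1:<port>` endpoint becomes an `ipc://<filename>` endpoint, so the ZMQ links stop consuming TCP ports entirely. A minimal standalone sketch of the PUSH/PULL-over-`ipc://` pattern being adopted (the endpoint name and payload are illustrative, not taken from this PR):

```python
import tempfile
import zmq

# Illustrative endpoint: a filesystem path used as a Unix domain socket.
# The PR likewise derives its ipc names from NamedTemporaryFile.
ipc_name = tempfile.NamedTemporaryFile(delete=False).name

ctx = zmq.Context()

# The receiver binds the path instead of a TCP port...
receiver = ctx.socket(zmq.PULL)
receiver.bind(f"ipc://{ipc_name}")

# ...and the sender connects to the same path; no free-port search is needed.
sender = ctx.socket(zmq.PUSH)
sender.connect(f"ipc://{ipc_name}")

sender.send_pyobj({"rid": 0, "text": "hello"})
print(receiver.recv_pyobj())  # {'rid': 0, 'text': 'hello'}
```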
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 9888440048..5577c7fa40 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -96,14 +96,10 @@ def __init__(

         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(
-                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
-            )
+            self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

             self.send_to_detokenizer = context.socket(zmq.PUSH)
-            self.send_to_detokenizer.connect(
-                f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-            )
+            self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
         else:
             self.recv_from_tokenizer = self.send_to_detokenizer = None

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index c0a0ff34ce..27cac65c3d 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -84,12 +84,10 @@ def __init__(
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
-        self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
+        self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")

         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(
-            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
-        )
+        self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

         # Read model args
         self.model_path = server_args.model_path
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 7ce80b57cb..3bf96d381a 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -16,7 +16,6 @@
 """Memory pool."""

 import logging
-from abc import ABC, abstractmethod
 from typing import List, Tuple, Union

 import numpy as np
@@ -62,9 +61,11 @@ def __init__(
         self,
         size: int,
         dtype: torch.dtype,
+        device: str,
     ):
         self.size = size
         self.dtype = dtype
+        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
@@ -84,7 +85,7 @@ def alloc(self, need_size: int):

         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-        return torch.tensor(select_index, dtype=torch.int32, device="cuda")
+        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

     def free(self, free_index: torch.Tensor):
         self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
@@ -123,7 +124,7 @@ def __init__(
         layer_num: int,
         device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)

         # [size, head_num, head_dim] for each layer
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -187,7 +188,7 @@ def __init__(
         layer_num: int,
         device: str,
     ):
-        super().__init__(size, dtype)
+        super().__init__(size, dtype, device)
         self.kv_lora_rank = kv_lora_rank

         # The padded slot 0 is used for writing dummy outputs from padded tokens.
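The memory-pool diff above is independent of the IPC change: `BaseTokenToKVPool` now takes the device from its subclasses instead of hard-coding `device="cuda"` in `alloc`, so allocated slot indices land on whatever device the pool was built for. A minimal sketch of the pattern (`TinyPool` is a made-up stand-in, not the real class):

```python
import numpy as np
import torch

class TinyPool:
    """Free-list allocator whose allocations follow a configurable device."""

    def __init__(self, size: int, device: str):
        self.device = device
        self.free_slots = np.arange(1, size + 1)  # slot 0 reserved for padding

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Previously hard-coded to "cuda"; honoring self.device lets the same
        # code serve "cpu", "cuda:1", etc.
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

pool = TinyPool(size=8, device="cpu")
print(pool.alloc(3))  # tensor([1, 2, 3], dtype=torch.int32)
```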
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 583e609895..0772816c9e 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -24,6 +24,7 @@
 import logging
 import multiprocessing as mp
 import os
+import random
 import threading
 import time
 from http import HTTPStatus
@@ -68,9 +69,9 @@
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
-    allocate_init_ports,
     assert_pkg_version,
     configure_logger,
+    is_port_available,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
@@ -302,18 +303,7 @@ def launch_server(
     _set_envs_and_config(server_args)

     # Allocate ports for inter-process communications
-    server_args.port, server_args.additional_ports = allocate_init_ports(
-        server_args.port,
-        server_args.additional_ports,
-        server_args.dp_size,
-    )
-    ports = server_args.additional_ports
-    port_args = PortArgs(
-        tokenizer_port=ports[0],
-        scheduler_input_port=ports[1],
-        detokenizer_port=ports[2],
-        nccl_ports=ports[3:],
-    )
+    port_args = PortArgs.init_new(server_args)
     logger.info(f"{server_args=}")

     # If using model from www.modelscope.cn, first download the model.
@@ -499,17 +489,16 @@ def __init__(
         self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

         # Pre-allocate ports
-        self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
-            self.server_args.port,
-            self.server_args.additional_ports,
-            self.server_args.dp_size,
-        )
+        for port in range(10000, 40000):
+            if is_port_available(port):
+                break
+            port += 1
+        self.server_args.port = port

         self.url = self.server_args.url()
-        self.generate_url = (
-            f"http://{self.server_args.host}:{self.server_args.port}/generate"
-        )
+        self.generate_url = self.url + "/generate"

+        # NOTE: We store pid instead of proc to fix some issues during __delete__
         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
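With the port list gone, `Runtime` only needs one free port for its HTTP server and finds it by linear scan. The scan relies on `is_port_available`, which this PR keeps in `sglang.srt.utils`; a typical bind-probe implementation of such a check looks like this (a sketch of the idea, not necessarily the project's exact code):

```python
import socket

def is_port_available(port: int) -> bool:
    """Return True if a TCP socket can bind the port on localhost (sketch)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False

# Mirrors the Runtime.__init__ scan above: take the first free port.
for port in range(10000, 40000):
    if is_port_available(port):
        break
```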
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ceacd93648..12f1303522 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -19,9 +19,10 @@
 import dataclasses
 import logging
 import random
-from typing import List, Optional, Union
+import tempfile
+from typing import List, Optional

-from sglang.srt.utils import is_hip, is_ipv6
+from sglang.srt.utils import is_hip, is_ipv6, is_port_available

 logger = logging.getLogger(__name__)

@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -134,11 +134,6 @@ def __post_init__(self):
         else:
             self.mem_fraction_static = 0.88

-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

@@ -199,13 +194,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -625,16 +613,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:

 @dataclasses.dataclass
 class PortArgs:
-    # The port for tokenizer to receive inputs from detokenizer (zmq)
-    tokenizer_port: int
-    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
-    scheduler_input_port: int
-    # The port for detokenizer to receive inputs from scheduler (zmq)
-    detokenizer_port: int
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]

+    @classmethod
+    def init_new(cls, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+

 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 1a08463b5f..dedcb9dfcb 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -177,35 +177,6 @@ def is_port_available(port):
     return False


-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (
         "LlavaLlamaForCausalLM" in model_architectures
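After this change, every call site that previously unpacked the result of `allocate_init_ports` reduces to a single `PortArgs.init_new(server_args)` call, which creates three temp-file-backed ipc names plus one NCCL port probed upward from `server_args.port + 1`. A hypothetical usage sketch (`SimpleNamespace` stands in for a full `ServerArgs`, since `init_new` only reads `.port`):

```python
from types import SimpleNamespace

from sglang.srt.server_args import PortArgs

# Stand-in for ServerArgs with the one field init_new consults.
server_args = SimpleNamespace(port=30000)

port_args = PortArgs.init_new(server_args)
print(port_args.tokenizer_ipc_name)  # e.g. /tmp/tmpab12cd34 (a temp-file path)
print(port_args.nccl_ports)          # e.g. [30001], first free port above 30000
```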