Use ipc instead of tcp in zmq (#1566)
merrymercy authored Oct 4, 2024
1 parent 32eb6e9 commit 114bbc8
Showing 9 changed files with 48 additions and 96 deletions.
@@ -223,7 +223,6 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=
model_path=args.model_path, # "liuhaotian/llava-v1.6-vicuna-7b",
tokenizer_path=tokenizer_path,
port=cur_port,
additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
json_model_override_args=json.dumps(model_override_args),
tp_size=1,
)
11 changes: 3 additions & 8 deletions python/sglang/bench_latency.py
@@ -66,9 +66,8 @@
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server import _set_envs_and_config
from sglang.srt.server_args import ServerArgs
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
allocate_init_ports,
configure_logger,
kill_child_process,
suppress_other_loggers,
@@ -127,11 +126,7 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
port_args = PortArgs.init_new(server_args)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
@@ -143,7 +138,7 @@ def load_model(server_args, tp_rank):
gpu_id=tp_rank,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
nccl_port=server_args.additional_ports[-1],
nccl_port=port_args.nccl_ports[0],
server_args=server_args,
)
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/detokenizer_manager.py
@@ -59,10 +59,10 @@ def __init__(
# Init inter-process communication
context = zmq.Context(2)
self.recv_from_scheduler = context.socket(zmq.PULL)
self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
self.recv_from_scheduler.bind(f"ipc://{port_args.detokenizer_ipc_name}")

self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_tokenizer.connect(f"ipc://{port_args.tokenizer_ipc_name}")

if server_args.skip_tokenizer_init:
self.tokenizer = None
8 changes: 2 additions & 6 deletions python/sglang/srt/managers/scheduler.py
@@ -96,14 +96,10 @@ def __init__(

if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
f"tcp://127.0.0.1:{port_args.detokenizer_port}"
)
self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}")
else:
self.recv_from_tokenizer = self.send_to_detokenizer = None

6 changes: 2 additions & 4 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -84,12 +84,10 @@ def __init__(
# Init inter-process communication
context = zmq.asyncio.Context(2)
self.recv_from_detokenizer = context.socket(zmq.PULL)
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.recv_from_detokenizer.bind(f"ipc://{port_args.tokenizer_ipc_name}")

self.send_to_scheduler = context.socket(zmq.PUSH)
self.send_to_scheduler.connect(
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
self.send_to_scheduler.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

# Read model args
self.model_path = server_args.model_path
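The tokenizer, scheduler, and detokenizer changes above all swap `tcp://127.0.0.1:<port>` endpoints for ZeroMQ's `ipc://` transport, which addresses a Unix domain socket by filesystem path instead of a TCP port. A minimal, self-contained sketch of the same PUSH/PULL pattern over an ipc endpoint (the temp-file path and payload are illustrative, not taken from the commit):

```python
import tempfile

import zmq

# An ipc endpoint is a filesystem path rather than a host:port pair.
# A throwaway temp-file name mirrors what PortArgs.init_new does later in this diff.
ipc_name = tempfile.NamedTemporaryFile(delete=False).name

ctx = zmq.Context()

# Receiver: bind a PULL socket to the ipc endpoint (Unix-like OS only).
receiver = ctx.socket(zmq.PULL)
receiver.bind(f"ipc://{ipc_name}")

# Sender: connect a PUSH socket to the same endpoint.
sender = ctx.socket(zmq.PUSH)
sender.connect(f"ipc://{ipc_name}")

sender.send_pyobj({"rid": 0, "text": "hello"})
print(receiver.recv_pyobj())  # {'rid': 0, 'text': 'hello'}
```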
9 changes: 5 additions & 4 deletions python/sglang/srt/mem_cache/memory_pool.py
@@ -16,7 +16,6 @@
"""Memory pool."""

import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import numpy as np
@@ -62,9 +61,11 @@ def __init__(
self,
size: int,
dtype: torch.dtype,
device: str,
):
self.size = size
self.dtype = dtype
self.device = device
if dtype == torch.float8_e5m2:
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
@@ -84,7 +85,7 @@ def alloc(self, need_size: int):
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]

return torch.tensor(select_index, dtype=torch.int32, device="cuda")
return torch.tensor(select_index, dtype=torch.int32, device=self.device)

def free(self, free_index: torch.Tensor):
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
@@ -123,7 +124,7 @@ def __init__(
layer_num: int,
device: str,
):
super().__init__(size, dtype)
super().__init__(size, dtype, device)

# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
@@ -187,7 +188,7 @@ def __init__(
layer_num: int,
device: str,
):
super().__init__(size, dtype)
super().__init__(size, dtype, device)

self.kv_lora_rank = kv_lora_rank
# The padded slot 0 is used for writing dummy outputs from padded tokens.
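The memory-pool change threads a `device` string through the base KV pool so allocated index tensors land on a configurable device instead of a hard-coded `"cuda"`. A toy sketch of that pattern (class and variable names are illustrative, not from the repository):

```python
import numpy as np
import torch


class FreeSlotPool:
    """Keep the free-slot list on CPU as a numpy array, but return
    allocated indices on whichever device the pool was built for."""

    def __init__(self, size: int, device: str):
        self.device = device
        self.free_slots = np.arange(size)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # The target device comes from the constructor, not a literal "cuda".
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)


pool = FreeSlotPool(16, device="cpu")  # or "cuda" when a GPU is present
print(pool.alloc(4))  # tensor([0, 1, 2, 3], dtype=torch.int32)
```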
31 changes: 10 additions & 21 deletions python/sglang/srt/server.py
@@ -24,6 +24,7 @@
import logging
import multiprocessing as mp
import os
import random
import threading
import time
from http import HTTPStatus
@@ -68,9 +69,9 @@
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
add_api_key_middleware,
allocate_init_ports,
assert_pkg_version,
configure_logger,
is_port_available,
kill_child_process,
maybe_set_triton_cache_manager,
prepare_model_and_tokenizer,
@@ -302,18 +303,7 @@ def launch_server(
_set_envs_and_config(server_args)

# Allocate ports for inter-process communications
server_args.port, server_args.additional_ports = allocate_init_ports(
server_args.port,
server_args.additional_ports,
server_args.dp_size,
)
ports = server_args.additional_ports
port_args = PortArgs(
tokenizer_port=ports[0],
scheduler_input_port=ports[1],
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")

# If using model from www.modelscope.cn, first download the model.
@@ -499,17 +489,16 @@ def __init__(
self.server_args = ServerArgs(*args, log_level=log_level, **kwargs)

# Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port,
self.server_args.additional_ports,
self.server_args.dp_size,
)
for port in range(10000, 40000):
if is_port_available(port):
break
port += 1
self.server_args.port = port

self.url = self.server_args.url()
self.generate_url = (
f"http://{self.server_args.host}:{self.server_args.port}/generate"
)
self.generate_url = self.url + "/generate"

# NOTE: We store pid instead of proc to fix some issues during __delete__
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)

45 changes: 24 additions & 21 deletions python/sglang/srt/server_args.py
@@ -19,9 +19,10 @@
import dataclasses
import logging
import random
from typing import List, Optional, Union
import tempfile
from typing import List, Optional

from sglang.srt.utils import is_hip, is_ipv6
from sglang.srt.utils import is_hip, is_ipv6, is_port_available

logger = logging.getLogger(__name__)

@@ -46,7 +47,6 @@ class ServerArgs:
# Port
host: str = "127.0.0.1"
port: int = 30000
additional_ports: Optional[Union[List[int], int]] = None

# Memory and scheduling
mem_fraction_static: Optional[float] = None
@@ -134,11 +134,6 @@ def __post_init__(self):
else:
self.mem_fraction_static = 0.88

if isinstance(self.additional_ports, int):
self.additional_ports = [self.additional_ports]
elif self.additional_ports is None:
self.additional_ports = []

if self.random_seed is None:
self.random_seed = random.randint(0, 1 << 30)

@@ -199,13 +194,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
parser.add_argument(
"--port", type=int, default=ServerArgs.port, help="The port of the server."
)
parser.add_argument(
"--additional-ports",
type=int,
nargs="*",
default=[],
help="The additional ports specified for the server.",
)
parser.add_argument(
"--tokenizer-mode",
type=str,
@@ -625,16 +613,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:

@dataclasses.dataclass
class PortArgs:
# The port for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_port: int
# The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
scheduler_input_port: int
# The port for detokenizer to receive inputs from scheduler (zmq)
detokenizer_port: int
# The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_ipc_name: str
# The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
scheduler_input_ipc_name: str
# The ipc filename for detokenizer to receive inputs from scheduler (zmq)
detokenizer_ipc_name: str

# The port for nccl initialization for multiple TP groups (torch.dist)
nccl_ports: List[int]

@classmethod
def init_new(self, server_args):
port = server_args.port + 1
while True:
if is_port_available(port):
break
port += 1

return PortArgs(
tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
nccl_ports=[port],
)


class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
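After this change, `PortArgs.init_new` derives three unique temp-file paths for the zmq ipc endpoints and scans upward from `port + 1` for a single free TCP port for NCCL. A rough usage sketch, assuming `ServerArgs` can be constructed with just a model path (the model path and printed values are placeholders):

```python
from sglang.srt.server_args import PortArgs, ServerArgs

# Placeholder model path, used only to build the args object.
server_args = ServerArgs(model_path="dummy/model")
port_args = PortArgs.init_new(server_args)

print(port_args.tokenizer_ipc_name)        # e.g. /tmp/tmpa1b2c3d4
print(port_args.scheduler_input_ipc_name)  # a second unique temp-file path
print(port_args.detokenizer_ipc_name)      # a third unique temp-file path
print(port_args.nccl_ports)                # e.g. [30001], the first free port above server_args.port
```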
29 changes: 0 additions & 29 deletions python/sglang/srt/utils.py
@@ -177,35 +177,6 @@ def is_port_available(port):
return False


def allocate_init_ports(
port: Optional[int] = None,
additional_ports: Optional[List[int]] = None,
dp_size: int = 1,
):
"""Allocate ports for all connections."""
if additional_ports:
ret_ports = [port] + additional_ports
else:
ret_ports = [port]

ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000

# HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
num_ports_needed = 4 + dp_size
while len(ret_ports) < num_ports_needed:
if cur_port not in ret_ports and is_port_available(cur_port):
ret_ports.append(cur_port)
cur_port += 1

if port is not None and ret_ports[0] != port:
logger.warning(
f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
)

return ret_ports[0], ret_ports[1:num_ports_needed]


def is_multimodal_model(model_architectures):
if (
"LlavaLlamaForCausalLM" in model_architectures
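With `allocate_init_ports` removed, the only port helper the new code relies on is `is_port_available` (its body sits just above the deleted function and is not shown in this diff). A minimal sketch of such a probe, assuming a simple bind test on localhost; the real helper in `sglang.srt.utils` may differ in details:

```python
import socket


def is_port_available(port: int) -> bool:
    """Return True if a TCP socket can bind the given port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False
```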
