From 23cc66f7b65f885969d4608fd4964e0ba98fb7f5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Fri, 11 Oct 2024 07:22:48 -0700 Subject: [PATCH] Add back data parallelism (#1635) --- .github/workflows/pr-test.yml | 11 +- python/sglang/bench_latency.py | 2 +- .../srt/managers/data_parallel_controller.py | 177 ++++++++++++++++++ python/sglang/srt/managers/scheduler.py | 9 +- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/server.py | 53 ++++-- python/sglang/srt/server_args.py | 13 +- 7 files changed, 228 insertions(+), 39 deletions(-) create mode 100644 python/sglang/srt/managers/data_parallel_controller.py diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 67551f09c0..5d84526ab6 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -255,12 +255,11 @@ jobs: python3 test_mla.py python3 test_mla_fp8.py - # Temporarily disabled - #- name: Evaluate Data Parallelism Accuracy (TP=2) - # timeout-minutes: 10 - # run: | - # cd test/srt - # python3 test_data_parallelism.py + - name: Evaluate Data Parallelism Accuracy (DP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 test_data_parallelism.py finish: needs: [ diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 87dad3ed0c..9540e2266e 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank): gpu_id=tp_rank, tp_rank=tp_rank, tp_size=server_args.tp_size, - nccl_port=port_args.nccl_ports[0], + nccl_port=port_args.nccl_port, server_args=server_args, ) rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}") diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py new file mode 100644 index 0000000000..1b7da747f1 --- /dev/null +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -0,0 +1,177 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +"""A controller that dispatches requests to multiple data parallel workers.""" + +import logging +import multiprocessing as mp +from enum import Enum, auto + +import zmq + +from sglang.srt.managers.io_struct import ( + TokenizedEmbeddingReqInput, + TokenizedGenerateReqInput, + TokenizedRewardReqInput, +) +from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.utils import ( + configure_logger, + kill_parent_process, + suppress_other_loggers, +) +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class LoadBalanceMethod(Enum): + """Load balance method.""" + + ROUND_ROBIN = auto() + SHORTEST_QUEUE = auto() + + @classmethod + def from_str(cls, method: str): + method = method.upper() + try: + return cls[method] + except KeyError as exc: + raise ValueError(f"Invalid load balance method: {method}") from exc + + +class DataParallelController: + """A controller that dispatches requests to multiple data parallel workers.""" + + def __init__(self, server_args, port_args) -> None: + # Parse args + self.server_args = server_args + self.port_args = port_args + self.load_balance_method = LoadBalanceMethod.from_str( + server_args.load_balance_method + ) + + # Init inter-process communication + self.context = zmq.Context(1 + server_args.dp_size) + self.recv_from_tokenizer = self.context.socket(zmq.PULL) + self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}") + + # Dispatch method + self.round_robin_counter = 0 + dispatch_lookup = { + LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler, + LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler, + } + self.dispatching = dispatch_lookup[self.load_balance_method] + + # Start data parallel workers + base_gpu_id = 0 + self.workers = [] + for dp_rank in range(server_args.dp_size): + tmp_port_args = PortArgs.init_new(server_args) + tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name + + send_to = self.launch_tensor_parallel_group( + server_args, + tmp_port_args, + base_gpu_id, + dp_rank, + ) + + self.workers.append(send_to) + base_gpu_id += server_args.tp_size + + def launch_tensor_parallel_group( + self, + server_args: ServerArgs, + port_args: PortArgs, + base_gpu_id: int, + dp_rank: int, + ): + # Launch tensor parallel scheduler processes + scheduler_procs = [] + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), + ) + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + + send_to = self.context.socket(zmq.PUSH) + send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}") + + # Wait for model to finish loading + for i in range(len(scheduler_pipe_readers)): + scheduler_pipe_readers[i].recv() + + return send_to + + def round_robin_scheduler(self, req): + self.workers[self.round_robin_counter].send_pyobj(req) + self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers) + + def shortest_queue_scheduler(self, input_requests): + raise NotImplementedError() + + def event_loop(self): + while True: + while True: + try: + recv_req = 
self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + + if isinstance( + recv_req, + ( + TokenizedGenerateReqInput, + TokenizedEmbeddingReqInput, + TokenizedRewardReqInput, + ), + ): + self.dispatching(recv_req) + else: + # Send other control messages to all workers + for worker in self.workers: + worker.queue.put(recv_req) + + +def run_data_parallel_controller_process( + server_args: ServerArgs, + port_args: PortArgs, + pipe_writer, +): + configure_logger(server_args) + suppress_other_loggers() + + try: + controller = DataParallelController(server_args, port_args) + pipe_writer.send("ready") + controller.event_loop() + except Exception: + msg = get_exception_traceback() + logger.error(msg) + kill_parent_process() diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 10411cd3e5..c6df4a2e81 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -142,7 +142,7 @@ def __init__( gpu_id=gpu_id, tp_rank=tp_rank, server_args=server_args, - nccl_port=port_args.nccl_ports[0], + nccl_port=port_args.nccl_port, ) self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group @@ -1042,9 +1042,14 @@ def run_scheduler_process( port_args: PortArgs, gpu_id: int, tp_rank: int, + dp_rank: Optional[int], pipe_writer, ): - configure_logger(server_args, prefix=f" TP{tp_rank}") + if dp_rank is None: + configure_logger(server_args, prefix=f" TP{tp_rank}") + else: + configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + suppress_other_loggers() try: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c2f2368e4f..5f0675de51 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -141,7 +141,7 @@ def __init__( self.init_attention_backend() def init_torch_distributed(self): - logger.info("Init torch distributed begin.") + logger.info("Init torch distributed begin.") # Init torch distributed if self.device == "cuda": torch.cuda.set_device(self.gpu_id) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ff33640129..233c6d29ce 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -44,6 +44,9 @@ from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) from sglang.srt.managers.detokenizer_manager import run_detokenizer_process from sglang.srt.managers.io_struct import ( EmbeddingReqInput, @@ -337,30 +340,40 @@ def launch_engine( server_args.model_path, server_args.tokenizer_path ) - # Launch tensor parallel scheduler processes - scheduler_procs = [] - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + scheduler_procs = [] + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = tp_rank % tp_size_per_node + proc = mp.Process( + 
target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + + if server_args.node_rank >= 1: + # For other nodes, they do not need to run tokenizer or detokenizer, + # so they can just wait here. + while True: + pass + else: + # Launch the data parallel controller reader, writer = mp.Pipe(duplex=False) - gpu_id = tp_rank % tp_size_per_node + scheduler_pipe_readers = [reader] proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, writer), + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), ) proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. - while True: - pass # Launch detokenizer process detoken_proc = mp.Process( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 757f2bcb74..4b70b393ec 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -574,7 +574,7 @@ def check_server_args(self): self.tp_size % self.nnodes == 0 ), "tp_size must be divisible by number of nodes" assert not ( - self.dp_size > 1 and self.node_rank is not None + self.dp_size > 1 and self.nnodes != 1 ), "multi-node data parallel is not supported" assert ( self.max_loras_per_batch > 0 @@ -583,11 +583,6 @@ def check_server_args(self): and (self.lora_paths is None or self.disable_radix_cache) ), "compatibility of lora and cuda graph and radix attention is in progress" - assert self.dp_size == 1, ( - "The support for data parallelism is temporarily disabled during refactor. " - "Please use sglang<=0.3.2 or wait for later updates." - ) - if isinstance(self.lora_paths, list): lora_paths = self.lora_paths self.lora_paths = {} @@ -626,8 +621,8 @@ class PortArgs: # The ipc filename for detokenizer to receive inputs from scheduler (zmq) detokenizer_ipc_name: str - # The port for nccl initialization for multiple TP groups (torch.dist) - nccl_ports: List[int] + # The port for nccl initialization (torch.dist) + nccl_port: int @staticmethod def init_new(server_args) -> "PortArgs": @@ -641,7 +636,7 @@ def init_new(server_args) -> "PortArgs": tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - nccl_ports=[port], + nccl_port=port, )
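
The controller launches one tensor-parallel scheduler group per data-parallel rank and advances base_gpu_id by tp_size after each group, so each replica occupies a disjoint GPU range on the node. The following is an illustrative sketch of that assignment (not a file in the tree), assuming a single node and the example values dp_size=2, tp_size=2:

# Sketch of the single-node GPU assignment done by
# DataParallelController.launch_tensor_parallel_group.
# dp_size=2 and tp_size=2 are example values chosen for illustration.
dp_size, tp_size, nnodes, node_rank = 2, 2, 1, 0

tp_size_per_node = tp_size // nnodes
base_gpu_id = 0
for dp_rank in range(dp_size):
    tp_rank_range = range(
        tp_size_per_node * node_rank,
        tp_size_per_node * (node_rank + 1),
    )
    for tp_rank in tp_rank_range:
        gpu_id = base_gpu_id + tp_rank % tp_size_per_node
        print(f"DP{dp_rank} TP{tp_rank} -> GPU {gpu_id}")
    base_gpu_id += tp_size  # the next replica starts on a fresh GPU range

# Prints: DP0 TP0 -> GPU 0, DP0 TP1 -> GPU 1, DP1 TP0 -> GPU 2, DP1 TP1 -> GPU 3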
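
Of the two LoadBalanceMethod values, only ROUND_ROBIN is wired up here; shortest_queue_scheduler raises NotImplementedError. The dispatch reduces to a modulo counter over the per-replica PUSH sockets. A minimal sketch of that ordering, with the zmq sockets replaced by a hypothetical in-memory stand-in (FakeWorker is not real code):

# Illustrative sketch of round_robin_scheduler; FakeWorker mimics only the
# send_pyobj method of a zmq PUSH socket so the dispatch order is visible.
class FakeWorker:
    def __init__(self, name):
        self.name = name
        self.received = []

    def send_pyobj(self, req):  # same method name as a zmq socket
        self.received.append(req)


workers = [FakeWorker("DP0"), FakeWorker("DP1")]
round_robin_counter = 0


def round_robin_scheduler(req):
    global round_robin_counter
    workers[round_robin_counter].send_pyobj(req)
    round_robin_counter = (round_robin_counter + 1) % len(workers)


for i in range(4):
    round_robin_scheduler(f"req-{i}")

print([w.received for w in workers])  # [['req-0', 'req-2'], ['req-1', 'req-3']]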
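
The IPC topology added by this patch: the tokenizer pushes tokenized requests to the controller's PULL socket (bound to port_args.scheduler_input_ipc_name), the controller pushes each request to the chosen replica's scheduler endpoint, and every replica reuses the parent detokenizer_ipc_name so all output converges on a single detokenizer. Below is a minimal pyzmq sketch of that PULL/PUSH relay, assuming a POSIX platform where ipc:// endpoints work; the endpoint names are placeholders generated the same way PortArgs.init_new does, and the dict payload stands in for a TokenizedGenerateReqInput:

import tempfile

import zmq

# Illustrative sketch: tokenizer -> controller -> one scheduler replica,
# over ipc PUSH/PULL pairs (placeholder endpoints, not the real PortArgs).
ctx = zmq.Context()

controller_ipc = tempfile.NamedTemporaryFile(delete=False).name
scheduler_ipc = tempfile.NamedTemporaryFile(delete=False).name

recv_from_tokenizer = ctx.socket(zmq.PULL)       # controller side
recv_from_tokenizer.bind(f"ipc://{controller_ipc}")

scheduler_input = ctx.socket(zmq.PULL)           # scheduler (replica) side
scheduler_input.bind(f"ipc://{scheduler_ipc}")

tokenizer_out = ctx.socket(zmq.PUSH)             # tokenizer side
tokenizer_out.connect(f"ipc://{controller_ipc}")

send_to_worker = ctx.socket(zmq.PUSH)            # controller -> replica
send_to_worker.connect(f"ipc://{scheduler_ipc}")

tokenizer_out.send_pyobj({"rid": "example-request"})
req = recv_from_tokenizer.recv_pyobj()           # blocking recv for simplicity;
                                                 # the controller drains with zmq.NOBLOCK
send_to_worker.send_pyobj(req)
print(scheduler_input.recv_pyobj())              # {'rid': 'example-request'}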