Add back data parallelism (#1635)
merrymercy authored Oct 11, 2024
1 parent 5d09ca5 commit 23cc66f
Showing 7 changed files with 228 additions and 39 deletions.
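For context: this commit re-enables launching the server with dp_size > 1 by adding a data parallel controller that dispatches requests across per-rank schedulers. A typical invocation after this change might look like the following (the model path is illustrative; --dp-size and --load-balance-method are the server arguments this diff references):

    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct \
        --dp-size 2 --load-balance-method round_robin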
11 changes: 5 additions & 6 deletions .github/workflows/pr-test.yml
@@ -255,12 +255,11 @@ jobs:
           python3 test_mla.py
           python3 test_mla_fp8.py
-      # Temporarily disabled
-      #- name: Evaluate Data Parallelism Accuracy (TP=2)
-      #  timeout-minutes: 10
-      #  run: |
-      #    cd test/srt
-      #    python3 test_data_parallelism.py
+      - name: Evaluate Data Parallelism Accuracy (DP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_data_parallelism.py
   finish:
     needs: [
2 changes: 1 addition & 1 deletion python/sglang/bench_latency.py
@@ -139,7 +139,7 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=port_args.nccl_ports[0],
+        nccl_port=port_args.nccl_port,
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
177 changes: 177 additions & 0 deletions python/sglang/srt/managers/data_parallel_controller.py
@@ -0,0 +1,177 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A controller that dispatches requests to multiple data parallel workers."""

import logging
import multiprocessing as mp
from enum import Enum, auto

import zmq

from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
TokenizedRewardReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import (
configure_logger,
kill_parent_process,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


class LoadBalanceMethod(Enum):
"""Load balance method."""

ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()

@classmethod
def from_str(cls, method: str):
method = method.upper()
try:
return cls[method]
except KeyError as exc:
raise ValueError(f"Invalid load balance method: {method}") from exc


class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""

def __init__(self, server_args, port_args) -> None:
# Parse args
self.server_args = server_args
self.port_args = port_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method
)

# Init inter-process communication
self.context = zmq.Context(1 + server_args.dp_size)
self.recv_from_tokenizer = self.context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"ipc://{port_args.scheduler_input_ipc_name}")

# Dispatch method
self.round_robin_counter = 0
dispatch_lookup = {
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
}
self.dispatching = dispatch_lookup[self.load_balance_method]

# Start data parallel workers
base_gpu_id = 0
self.workers = []
for dp_rank in range(server_args.dp_size):
tmp_port_args = PortArgs.init_new(server_args)
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

send_to = self.launch_tensor_parallel_group(
server_args,
tmp_port_args,
base_gpu_id,
dp_rank,
)

self.workers.append(send_to)
base_gpu_id += server_args.tp_size

def launch_tensor_parallel_group(
self,
server_args: ServerArgs,
port_args: PortArgs,
base_gpu_id: int,
dp_rank: int,
):
# Launch tensor parallel scheduler processes
scheduler_procs = []
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
)
proc.start()
scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)

send_to = self.context.socket(zmq.PUSH)
send_to.connect(f"ipc://{port_args.scheduler_input_ipc_name}")

# Wait for model to finish loading
for i in range(len(scheduler_pipe_readers)):
scheduler_pipe_readers[i].recv()

return send_to

def round_robin_scheduler(self, req):
self.workers[self.round_robin_counter].send_pyobj(req)
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)

def shortest_queue_scheduler(self, input_requests):
raise NotImplementedError()

def event_loop(self):
while True:
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break

if isinstance(
recv_req,
(
TokenizedGenerateReqInput,
TokenizedEmbeddingReqInput,
TokenizedRewardReqInput,
),
):
self.dispatching(recv_req)
else:
# Send other control messages to all workers
for worker in self.workers:
worker.send_pyobj(recv_req)


def run_data_parallel_controller_process(
server_args: ServerArgs,
port_args: PortArgs,
pipe_writer,
):
configure_logger(server_args)
suppress_other_loggers()

try:
controller = DataParallelController(server_args, port_args)
pipe_writer.send("ready")
controller.event_loop()
except Exception:
msg = get_exception_traceback()
logger.error(msg)
kill_parent_process()
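The controller above determines GPU placement: each data parallel rank claims a contiguous block of tp_size GPUs (base_gpu_id advances by tp_size per rank), and incoming requests are then spread across the per-rank schedulers round-robin. A minimal, self-contained sketch of that mapping, assuming a single node (nnodes == 1); the helper name is illustrative:

def gpu_layout(dp_size, tp_size, nnodes=1, node_rank=0):
    # Mirrors DataParallelController.launch_tensor_parallel_group:
    # each local TP rank offsets into the current DP rank's GPU block.
    layout, base_gpu_id = {}, 0
    tp_size_per_node = tp_size // nnodes
    for dp_rank in range(dp_size):
        tp_rank_range = range(
            tp_size_per_node * node_rank,
            tp_size_per_node * (node_rank + 1),
        )
        layout[dp_rank] = [
            base_gpu_id + tp_rank % tp_size_per_node for tp_rank in tp_rank_range
        ]
        base_gpu_id += tp_size
    return layout

# With dp_size=2 and tp_size=2, DP rank 0 gets GPUs [0, 1] and DP rank 1 gets [2, 3].
assert gpu_layout(2, 2) == {0: [0, 1], 1: [2, 3]}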
9 changes: 7 additions & 2 deletions python/sglang/srt/managers/scheduler.py
@@ -142,7 +142,7 @@ def __init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             server_args=server_args,
-            nccl_port=port_args.nccl_ports[0],
+            nccl_port=port_args.nccl_port,
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
@@ -1042,9 +1042,14 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    dp_rank: Optional[int],
     pipe_writer,
 ):
-    configure_logger(server_args, prefix=f" TP{tp_rank}")
+    if dp_rank is None:
+        configure_logger(server_args, prefix=f" TP{tp_rank}")
+    else:
+        configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
+
     suppress_other_loggers()

     try:
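The scheduler entry point now takes an Optional dp_rank, which only affects the log prefix. A small sketch of the rule as a pure function (the name is illustrative):

def log_prefix(tp_rank, dp_rank=None):
    # TP-only mode tags logs " TP{t}"; with data parallelism, " DP{d} TP{t}".
    return f" TP{tp_rank}" if dp_rank is None else f" DP{dp_rank} TP{tp_rank}"

assert log_prefix(0) == " TP0"
assert log_prefix(1, dp_rank=0) == " DP0 TP1"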
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -141,7 +141,7 @@ def __init__(
         self.init_attention_backend()

     def init_torch_distributed(self):
-        logger.info("Init torch distributed begin.")
+        logger.info("Init torch distributed begin.")
         # Init torch distributed
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
53 changes: 33 additions & 20 deletions python/sglang/srt/server.py
@@ -44,6 +44,9 @@
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
@@ -337,30 +340,40 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )

-    # Launch tensor parallel scheduler processes
-    scheduler_procs = []
-    scheduler_pipe_readers = []
-    tp_size_per_node = server_args.tp_size // server_args.nnodes
-    tp_rank_range = range(
-        tp_size_per_node * server_args.node_rank,
-        tp_size_per_node * (server_args.node_rank + 1),
-    )
-    for tp_rank in tp_rank_range:
+    if server_args.dp_size == 1:
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        if server_args.node_rank >= 1:
+            # For other nodes, they do not need to run tokenizer or detokenizer,
+            # so they can just wait here.
+            while True:
+                pass
+    else:
+        # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
-        gpu_id = tp_rank % tp_size_per_node
+        scheduler_pipe_readers = [reader]
         proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, writer),
+            target=run_data_parallel_controller_process,
+            args=(server_args, port_args, writer),
         )
         proc.start()
-        scheduler_procs.append(proc)
-        scheduler_pipe_readers.append(reader)
-
-    if server_args.node_rank >= 1:
-        # For other nodes, they do not need to run tokenizer or detokenizer,
-        # so they can just wait here.
-        while True:
-            pass

     # Launch detokenizer process
     detoken_proc = mp.Process(
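Both branches above rely on the same readiness handshake: every spawned process sends one message over a one-way pipe once its model has loaded, and launch_engine blocks until all pipes have reported (the data parallel controller repeats the pattern internally for each TP group). A stdlib-only sketch of that pattern, with an illustrative worker body:

import multiprocessing as mp

def worker(writer):
    # ... load the model here ...
    writer.send("ready")  # signal readiness to the parent exactly once

if __name__ == "__main__":
    readers = []
    for _ in range(2):
        reader, writer = mp.Pipe(duplex=False)
        mp.Process(target=worker, args=(writer,)).start()
        readers.append(reader)
    for reader in readers:
        reader.recv()  # blocks until every worker has reported ready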
13 changes: 4 additions & 9 deletions python/sglang/srt/server_args.py
@@ -574,7 +574,7 @@ def check_server_args(self):
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.node_rank is not None
+            self.dp_size > 1 and self.nnodes != 1
         ), "multi-node data parallel is not supported"
         assert (
             self.max_loras_per_batch > 0
@@ -583,11 +583,6 @@
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"

-        assert self.dp_size == 1, (
-            "The support for data parallelism is temporarily disabled during refactor. "
-            "Please use sglang<=0.3.2 or wait for later updates."
-        )
-
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
@@ -626,8 +621,8 @@ class PortArgs:
     # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_ipc_name: str

-    # The port for nccl initialization for multiple TP groups (torch.dist)
-    nccl_ports: List[int]
+    # The port for nccl initialization (torch.dist)
+    nccl_port: int

@staticmethod
def init_new(server_args) -> "PortArgs":
@@ -641,7 +636,7 @@ def init_new(server_args) -> "PortArgs":
         tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
         scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
         detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-        nccl_ports=[port],
+        nccl_port=port,
     )
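A hedged, self-contained illustration of the PortArgs reshaping above: nccl_ports: List[int] becomes a single nccl_port: int, and (as in the controller) each data parallel rank gets a fresh PortArgs, and with it a fresh port, while sharing one detokenizer channel. The class and values below are simplified stand-ins, not the real definitions:

from dataclasses import dataclass

@dataclass
class PortArgsSketch:  # simplified stand-in for sglang's PortArgs
    scheduler_input_ipc_name: str
    detokenizer_ipc_name: str
    nccl_port: int  # was nccl_ports: List[int] before this commit

shared_detok = "/tmp/detok.ipc"
per_rank = [
    PortArgsSketch(f"/tmp/sched_{dp_rank}.ipc", shared_detok, 29500 + dp_rank)
    for dp_rank in range(2)
]
assert per_rank[0].nccl_port != per_rank[1].nccl_port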

