Commit f7b7547

[core][distributed] add zmq fallback for broadcasting large objects (vllm-project#6183)

(cherry picked from commit da78cae)
youkaichao authored and adityagoel14 committed Jul 10, 2024
1 parent 8822361 commit f7b7547
Showing 6 changed files with 274 additions and 80 deletions.
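
The commit adds a size-based fallback: objects whose serialized form fits in a fixed-size shared-memory ring-buffer chunk keep using the existing shared-memory path, while larger objects are sent over a ZeroMQ socket instead. The sketch below only illustrates that dispatch decision under assumed names and an assumed threshold (choose_transport and MAX_INLINE_BYTES are made up here, not vLLM's API); the real limit is whatever chunk size the MessageQueue is constructed with (the test below uses 40 * 1024 bytes).

import pickle

# Illustrative threshold only (assumption): the real limit is the chunk
# size the MessageQueue is created with.
MAX_INLINE_BYTES = 40 * 1024


def choose_transport(obj):
    """Serialize an object and pick a transport for broadcasting it.

    Small payloads fit into a fixed-size shared-memory ring-buffer chunk;
    anything larger falls back to a ZeroMQ socket, which has no fixed
    upper bound on message size.
    """
    payload = pickle.dumps(obj)
    if len(payload) <= MAX_INLINE_BYTES:
        return "shm_ring_buffer", payload
    return "zmq_socket", payload


print(choose_transport({"sampling_params": [0] * 10})[0])     # shm_ring_buffer
print(choose_transport({"prompt_tokens": [0] * 100_000})[0])  # zmq_socket
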
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -21,3 +21,4 @@ lm-format-enforcer == 0.10.1
outlines >= 0.0.43 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
+pyzmq
5 changes: 3 additions & 2 deletions tests/distributed/test_same_node.py
@@ -2,10 +2,11 @@

import torch

-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as

torch.distributed.init_process_group(backend="gloo")
-test_result = is_in_the_same_node(torch.distributed.group.WORLD)
+test_result = all(
+    in_the_same_node_as(torch.distributed.group.WORLD, source_rank=0))

expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
assert test_result == expected, f"Expected {expected}, got {test_result}"
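
The rename also changes the return type: is_in_the_same_node returned a single boolean for the whole group, whereas in_the_same_node_as(group, source_rank=0) appears to return one flag per rank, indicating whether that rank lives on the same node as source_rank, which is why callers now wrap it in all(). A tiny illustration with made-up values:

same_node_flags = [True, True, True, False]  # hypothetical result for a 4-rank group

# The old API answered only "is the whole group on one node?"; with a
# per-rank list that question becomes an all() over the flags.
print(all(same_node_flags))  # False: not every rank shares a node with rank 0
print(same_node_flags[3])    # False: rank 3 is the one on a different node
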
17 changes: 3 additions & 14 deletions tests/distributed/test_shm_broadcast.py
@@ -6,8 +6,7 @@
import numpy as np
import torch.distributed as dist

-from vllm.distributed.device_communicators.shm_broadcast import (
-    ShmRingBuffer, ShmRingBufferIO)
+from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
from vllm.utils import update_environment_variables


@@ -56,8 +55,8 @@ def wrapped_fn(env):
@worker_fn_wrapper
def worker_fn():
writer_rank = 2
-broadcaster = ShmRingBufferIO.create_from_process_group(
-    dist.group.WORLD, 1024 * 1024, 2, writer_rank)
+broadcaster = MessageQueue.create_from_process_group(
+    dist.group.WORLD, 40 * 1024, 2, writer_rank)
if dist.get_rank() == writer_rank:
seed = random.randint(0, 1000)
dist.broadcast_object_list([seed], writer_rank)
@@ -87,13 +86,3 @@ def worker_fn():

def test_shm_broadcast():
distributed_run(worker_fn, 4)


-def test_singe_process():
-    buffer = ShmRingBuffer(1, 1024, 4)
-    reader = ShmRingBufferIO(buffer, reader_rank=0)
-    writer = ShmRingBufferIO(buffer, reader_rank=-1)
-    writer.enqueue([0])
-    writer.enqueue([1])
-    assert reader.dequeue() == [0]
-    assert reader.dequeue() == [1]
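
pyzmq, added to requirements-common.txt above, provides the socket layer for the large-object fallback. The snippet below is only a minimal single-process illustration of shipping a large pickled object through a pair of zmq sockets; it is not the MessageQueue implementation, which spans processes and combines this path with the shared-memory ring buffer.

import pickle
import zmq

# Two PAIR sockets over an in-process transport stand in for the writer
# and a reader that would normally live in different processes.
ctx = zmq.Context()
writer = ctx.socket(zmq.PAIR)
writer.bind("inproc://large-object-demo")
reader = ctx.socket(zmq.PAIR)
reader.connect("inproc://large-object-demo")

# A payload far too large for a small fixed-size shared-memory chunk.
large_obj = {"token_ids": list(range(200_000))}
writer.send(pickle.dumps(large_obj))

received = pickle.loads(reader.recv())
assert received == large_obj

reader.close()
writer.close()
ctx.term()
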
4 changes: 2 additions & 2 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -9,7 +9,7 @@
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
-from vllm.distributed.parallel_state import is_in_the_same_node
+from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless, is_full_nvlink

@@ -64,7 +64,7 @@ def __init__(self,
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")

-if not is_in_the_same_node(group):
+if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
