def is_valid_ipv6_address(address: str) -> bool:
    """Return True if *address* is a textual IPv6 address literal.

    Validation is delegated to ``ipaddress.IPv6Address``. The address must
    be given bare: bracketed forms such as ``"[::1]"``, prefixes such as
    ``"::1/128"``, IPv4 addresses and host names are all rejected.

    Args:
        address: Candidate address string, e.g. ``"::1"`` or
            ``"2001:db8::1"``.

    Returns:
        True when ``address`` parses as an IPv6 address, False otherwise.
    """
    # ipaddress.IPv6Address also accepts packed integers (bool included) and
    # raw bytes. Those are not textual addresses and would produce a bogus
    # URL if interpolated into "tcp://...", so reject them up front.
    if isinstance(address, (int, bytes)):
        return False
    try:
        ipaddress.IPv6Address(address)
    except ValueError:
        # AddressValueError is a ValueError subclass: not a valid literal.
        return False
    return True
int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848