[aDAG] Allow custom NCCL group for aDAG #47141

Merged · 8 commits · Sep 6, 2024
Changes from 7 commits
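In short, this PR lets a user hand the compiled DAG an existing GPUCommunicator instead of having the DAG create its own NCCL group. A condensed sketch of the intended usage, based on test_torch_tensor_custom_comm added below (TorchTensorWorker and TestNcclGroup are the actor class and GPUCommunicator subclass defined in that test; the InputNode and TorchTensorType import paths are assumptions, not part of this diff):

import ray
from ray.dag import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType
from cupy.cuda import nccl

# GPU actors and a user-managed communicator, as defined in the test below.
sender = TorchTensorWorker.options(num_cpus=0, num_gpus=1).remote()
receiver = TorchTensorWorker.options(num_cpus=0, num_gpus=1).remote()
comm_id = nccl.get_unique_id()
nccl_group = TestNcclGroup(2, comm_id, [sender, receiver])

with InputNode() as inp:
    dag = sender.send_with_tuple_args.bind(inp)
    # The custom communicator is attached through the type hint's transport.
    dag = dag.with_type_hint(TorchTensorType(transport=nccl_group))
    dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()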
15 changes: 14 additions & 1 deletion python/ray/dag/compiled_dag_node.py
@@ -9,6 +9,7 @@
from typing import NamedTuple

from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.util.annotations import PublicAPI
@@ -640,6 +641,7 @@ def __init__(
# Type hints specified by the user for DAG (intermediate) outputs.
self._type_hints = []

self._custom_nccl_group: Optional[GPUCommunicator] = None
# Uniquely identifies the NCCL communicator that will be used within
# this DAG, if any.
self._nccl_group_id: Optional[str] = None
@@ -806,6 +808,17 @@ def _preprocess(self) -> None:
if dag_node.type_hint.requires_nccl():
# Add all writers to the NCCL group.
nccl_actors.add(actor_handle)
custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
if custom_nccl_group is not None:
if self._custom_nccl_group is not None:
assert self._custom_nccl_group == custom_nccl_group, (
"Accelerated DAGs currently only support "
"a single custom NCCL group, but multiple "
"have been specified. Check all the "
"TorchTensor(transport=nccl_group) type hints "
"to make sure only one NCCL group is used."
)
self._custom_nccl_group = custom_nccl_group
elif isinstance(dag_node, InputNode):
if dag_node.type_hint.requires_nccl():
raise ValueError(
@@ -916,7 +929,7 @@ def _preprocess(self) -> None:
if None in nccl_actors:
raise ValueError("Driver cannot participate in the NCCL group.")
if nccl_actors and self._nccl_group_id is None:
self._nccl_group_id = _init_nccl_group(nccl_actors)
self._nccl_group_id = _init_nccl_group(nccl_actors, self._custom_nccl_group)

if direct_input:
self._input_num_positional_args = 1
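The net effect of the _preprocess change above: a DAG may carry at most one custom communicator, so every NCCL-annotated type hint must return either None or the same GPUCommunicator. A standalone sketch of that consistency check (the helper name is illustrative, not a Ray API):

from typing import Iterable, Optional

from ray.experimental.channel import GPUCommunicator


def pick_single_custom_group(
    groups: Iterable[Optional[GPUCommunicator]],
) -> Optional[GPUCommunicator]:
    # Mirrors the assertion in _preprocess: all non-None entries must be the
    # same communicator, otherwise the DAG is rejected.
    chosen: Optional[GPUCommunicator] = None
    for group in groups:
        if group is None:
            continue
        if chosen is not None and chosen != group:
            raise ValueError(
                "Accelerated DAGs currently only support a single custom "
                "NCCL group, but multiple have been specified."
            )
        chosen = group
    return chosen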
206 changes: 206 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -3,6 +3,13 @@
import os
import re
import sys
from typing import List, Optional, Tuple
from ray.experimental.channel.gpu_communicator import (
GPUCommunicator,
TorchTensorAllocator,
)
from ray.experimental.channel.nccl_group import _NcclGroup
import socket
import torch
import time

@@ -33,6 +40,11 @@ class TorchTensorWorker:
def __init__(self):
self.device = torch_utils.get_devices()[0]

def init_distributed(self, world_size, rank):
torch.distributed.init_process_group(
backend="nccl", world_size=world_size, rank=rank
)

def send(self, shape, dtype, value: int, send_tensor=True):
if not send_tensor:
return 1
@@ -291,6 +303,200 @@ def test_torch_tensor_nccl_dynamic(ray_start_regular):
compiled_dag.teardown()


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_custom_comm(ray_start_regular):
if not USE_GPU:
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 GPUs"

actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)

sender = actor_cls.remote()
receiver = actor_cls.remote()

class TestNcclGroup(GPUCommunicator):
"""
A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`.
"""

def __init__(self, world_size, comm_id, actor_handles):
self._world_size = world_size
self._comm_id = comm_id
self._actor_handles = actor_handles
self._inner = None

def initialize(self, rank: int) -> None:
self._inner = _NcclGroup(
self._world_size,
self._comm_id,
rank,
self._actor_handles,
torch.cuda.current_stream().cuda_stream,
)

def get_rank(self, actor: ray.actor.ActorHandle) -> int:
# Implement this without forwarding to `_inner` to allow the method
# to be called before initialization.
actor_ids = [a._ray_actor_id for a in self._actor_handles]
try:
rank = actor_ids.index(actor._ray_actor_id)
except ValueError:
raise ValueError("Actor is not in the NCCL group.")
return rank

def get_world_size(self) -> int:
# Implement this without forwarding to `_inner` to allow the method
# to be called before initialization.
return self._world_size

def get_self_rank(self) -> Optional[int]:
if self._inner is None:
return None
return self._inner.get_self_rank()

def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
return self._actor_handles

def send(self, value: "torch.Tensor", peer_rank: int) -> None:
return self._inner.send(value, peer_rank)

def recv(
self,
shape: Tuple[int],
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
return self._inner.recv(shape, dtype, peer_rank, allocator=allocator)

def destroy(self) -> None:
return self._inner.destroy()

from cupy.cuda import nccl

comm_id = nccl.get_unique_id()
nccl_group = TestNcclGroup(2, comm_id, [sender, receiver])
with InputNode() as inp:
dag = sender.send_with_tuple_args.bind(inp)
dag = dag.with_type_hint(TorchTensorType(transport=nccl_group))
dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()
for i in range(3):
i += 1
shape = (i * 10,)
dtype = torch.float16
args = (shape, dtype, i)
ref = compiled_dag.execute(args)
result = ray.get(ref)
assert result == (i, shape, dtype)

compiled_dag.teardown()


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_custom_comm_inited(ray_start_regular):
if not USE_GPU:
pytest.skip("NCCL tests require GPUs")

assert (
sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
), "This test requires at least 2 GPUs"
runtime_env = {
"env_vars": {
"MASTER_ADDR": socket.gethostbyname(socket.gethostname()),
"MASTER_PORT": "8888",
}
}
actor_cls = TorchTensorWorker.options(
num_cpus=0, num_gpus=1, runtime_env=runtime_env
)

sender = actor_cls.remote()
receiver = actor_cls.remote()

# Simulates that the distributed environment (e.g., torch.distributed)
# has already been set up.
refs = [
sender.init_distributed.remote(2, 0),
receiver.init_distributed.remote(2, 1),
]
ray.wait(refs)

class InitedNcclGroup(GPUCommunicator):
"""
A custom NCCL group based on existing torch.distributed setup.
"""

def __init__(self, world_size, actor_handles):
self._world_size = world_size
self._actor_handles = actor_handles
self._rank = None

def initialize(self, rank: int) -> None:
expected_rank = self.get_rank(ray.get_runtime_context().current_actor)
assert (
rank == expected_rank
), f"NCCL actor's rank {rank} does not match expected rank {expected_rank}"
self._rank = rank
self._device = torch_utils.get_devices()[0]

def get_rank(self, actor: ray.actor.ActorHandle) -> int:
actor_ids = [a._ray_actor_id for a in self._actor_handles]
try:
rank = actor_ids.index(actor._ray_actor_id)
except ValueError:
raise ValueError("Actor is not in the NCCL group.")
return rank

def get_world_size(self) -> int:
return self._world_size

def get_self_rank(self) -> Optional[int]:
return self._rank

def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
return self._actor_handles

def send(self, value: "torch.Tensor", peer_rank: int) -> None:
torch.distributed.send(value, peer_rank)

def recv(
self,
shape: Tuple[int],
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
tensor = torch.empty(torch.Size(shape), dtype=dtype, device=self._device)
torch.distributed.recv(tensor, peer_rank)
return tensor

def destroy(self) -> None:
pass

nccl_group = InitedNcclGroup(2, [sender, receiver])
with InputNode() as inp:
dag = sender.send_with_tuple_args.bind(inp)
dag = dag.with_type_hint(TorchTensorType(transport=nccl_group))
dag = receiver.recv.bind(dag)

compiled_dag = dag.experimental_compile()
for i in range(3):
i += 1
shape = (i * 10,)
dtype = torch.float16
args = (shape, dtype, i)
ref = compiled_dag.execute(args)
result = ray.get(ref)
assert result == (i, shape, dtype)

compiled_dag.teardown()


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_wrong_shape(ray_start_regular):
if not USE_GPU:
Expand Down
2 changes: 2 additions & 0 deletions python/ray/experimental/channel/__init__.py
@@ -10,6 +10,7 @@
SynchronousWriter,
WriterInterface,
)
from ray.experimental.channel.gpu_communicator import GPUCommunicator
from ray.experimental.channel.intra_process_channel import IntraProcessChannel
from ray.experimental.channel.shared_memory_channel import Channel, CompositeChannel
from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel
@@ -19,6 +20,7 @@
"AwaitableBackgroundWriter",
"CachedChannel",
"Channel",
"GPUCommunicator",
"ReaderInterface",
"SynchronousReader",
"SynchronousWriter",
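With GPUCommunicator re-exported here, user code can check a transport against the interface from the package root. A small sketch (the helper name is illustrative):

from ray.experimental.channel import GPUCommunicator


def is_custom_transport(transport) -> bool:
    # Distinguish a user-provided communicator from a built-in transport
    # string such as "nccl".
    return isinstance(transport, GPUCommunicator)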
12 changes: 10 additions & 2 deletions python/ray/experimental/channel/common.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import ray
from ray.experimental.channel.nccl_group import _NcclGroup
from ray.experimental.channel.gpu_communicator import GPUCommunicator
from ray.experimental.channel.serialization_context import _SerializationContext
from ray.util.annotations import DeveloperAPI, PublicAPI

@@ -100,6 +100,14 @@ def requires_nccl(self) -> bool:
# By default, channels do not require NCCL.
return False

def get_custom_nccl_group(self) -> Optional[GPUCommunicator]:
"""
Return the custom NCCL group if one is specified.
"""
if self._contains_type is not None:
return self._contains_type.get_custom_nccl_group()
return None

def set_nccl_group_id(self, group_id: str) -> None:
raise NotImplementedError

@@ -112,7 +120,7 @@ class ChannelContext:

def __init__(self):
# Used for the torch.Tensor NCCL transport.
self.nccl_groups: Dict[str, "_NcclGroup"] = {}
self.nccl_groups: Dict[str, "GPUCommunicator"] = {}

@staticmethod
def get_current() -> "ChannelContext":
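For concreteness, the hook added above is what _preprocess queries on each type hint; a TorchTensorType constructed with a GPUCommunicator transport is expected to return that communicator (that override lives outside this file, so treat this as an assumption, along with the import path):

from typing import Optional

from ray.experimental.channel import GPUCommunicator
from ray.experimental.channel.torch_tensor_type import TorchTensorType  # path assumed


def custom_group_of(hint: TorchTensorType) -> Optional[GPUCommunicator]:
    # Returns the user-provided communicator, or None for built-in transports.
    return hint.get_custom_nccl_group()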
15 changes: 14 additions & 1 deletion python/ray/experimental/channel/conftest.py
@@ -1,11 +1,13 @@
import asyncio
from collections import defaultdict
from typing import Optional, Tuple
from unittest import mock

import torch

import ray
import ray.experimental.channel as ray_channel
from ray.experimental.channel.gpu_communicator import TorchTensorAllocator


@ray.remote(num_cpus=0)
@@ -74,13 +76,24 @@ def send(self, tensor: torch.Tensor, peer_rank: int):
ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor))
self.num_ops[barrier_key] += 1

def recv(self, buf: torch.Tensor, peer_rank: int):
def recv(
self,
shape: Tuple[int],
dtype: torch.dtype,
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
):
# "Receive" the tensor from the barrier actor.
barrier_key = f"barrier-{peer_rank}-{self.get_self_rank()}"
barrier = ray.get_actor(name=barrier_key)
received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key]))
assert (
allocator is not None
), "torch tensor allocator is required for MockNcclGroup"
buf = allocator(shape, dtype)
buf[:] = received_tensor[:]
self.num_ops[barrier_key] += 1
return buf


def start_nccl_mock():
Expand Down