Commit d1aff64
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
ruisearch42 committed Sep 5, 2024
1 parent c19d7a5 commit d1aff64
Showing 8 changed files with 239 additions and 77 deletions.
34 changes: 15 additions & 19 deletions python/ray/dag/compiled_dag_node.py
@@ -35,7 +35,6 @@
)

from ray.experimental.channel.torch_tensor_nccl_channel import (
-    _set_nccl_group,
    _init_nccl_group,
    _destroy_nccl_group,
)
@@ -531,7 +530,6 @@ def __init__(
        enable_asyncio: bool = False,
        asyncio_max_queue_size: Optional[int] = None,
        max_buffered_results: Optional[int] = None,
-        nccl_group: Optional[GPUCommunicator] = None,
    ):
        """
        Args:
@@ -560,9 +558,6 @@ def __init__(
                executions is beyond the DAG capacity, the new execution would
                be blocked in the first place; therefore, this limit is only
                enforced when it is smaller than the DAG capacity.
-            nccl_group: The NCCL group to use for this DAG. If None, the DAG
-                will create a NCCL group internally if NCCL communication is
-                needed.

        Returns:
            Channel: A wrapper around ray.ObjectRef.
@@ -646,7 +641,7 @@ def __init__(
        # Type hints specified by the user for DAG (intermediate) outputs.
        self._type_hints = []

-        self._custom_nccl_group: Optional[GPUCommunicator] = nccl_group
+        self._custom_nccl_group: Optional[GPUCommunicator] = None
        # Uniquely identifies the NCCL communicator that will be used within
        # this DAG, if any.
        self._nccl_group_id: Optional[str] = None
@@ -813,6 +808,15 @@ def _preprocess(self) -> None:
                if dag_node.type_hint.requires_nccl():
                    # Add all writers to the NCCL group.
                    nccl_actors.add(actor_handle)
+                    custom_nccl_group = dag_node.type_hint.get_custom_nccl_group()
+                    if custom_nccl_group is not None:
+                        if self._custom_nccl_group is not None:
+                            assert self._custom_nccl_group == custom_nccl_group, (
+                                "Accelerated DAGs currently only support "
+                                "a single custom NCCL group, but multiple "
+                                "have been specified."
+                            )
+                        self._custom_nccl_group = custom_nccl_group
            elif isinstance(dag_node, InputNode):
                if dag_node.type_hint.requires_nccl():
                    raise ValueError(
@@ -923,16 +927,7 @@ def _preprocess(self) -> None:
        if None in nccl_actors:
            raise ValueError("Driver cannot participate in the NCCL group.")
        if nccl_actors and self._nccl_group_id is None:
-            if self._custom_nccl_group:
-                self._nccl_group_id = _set_nccl_group(
-                    self._custom_nccl_group, nccl_actors
-                )
-            else:
-                self._nccl_group_id = _init_nccl_group(nccl_actors)
-        elif self._custom_nccl_group:
-            raise ValueError(
-                "The DAG does not use NCCL, but a custom NCCL group was provided."
-            )
+            self._nccl_group_id = _init_nccl_group(nccl_actors, self._custom_nccl_group)

        if direct_input:
            self._input_num_positional_args = 1
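Note that `_init_nccl_group` now takes the custom group as a second argument. The matching change in `torch_tensor_nccl_channel.py` is one of the two changed files that did not render on this page; judging only from this call site, a plausible sketch of the updated helper is:

    # Hypothetical sketch (the real diff did not render here): when a custom
    # GPUCommunicator is supplied, register and reuse it instead of creating
    # a brand-new NCCL communicator for the given actors.
    def _init_nccl_group(
        actors: List["ray.actor.ActorHandle"],
        custom_nccl_group: Optional[GPUCommunicator] = None,
    ) -> str:
        ...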
@@ -1747,7 +1742,10 @@ def teardown(self, wait: bool):
                        logger.exception("Error cancelling worker task")
                        pass

-                if outer._nccl_group_id is not None:
+                if (
+                    outer._nccl_group_id is not None
+                    and outer._custom_nccl_group is None
+                ):
                    _destroy_nccl_group(outer._nccl_group_id)

                if wait:
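The new `outer._custom_nccl_group is None` guard changes ownership semantics: teardown only destroys NCCL groups the DAG created internally. Under that reading, a caller who supplies a custom communicator remains responsible for its cleanup, roughly:

    # Sketch under the assumption above: teardown leaves a user-provided
    # communicator alive, so the caller releases it explicitly.
    compiled_dag.teardown()
    nccl_group.destroy()  # destroy() per the GPUCommunicator interface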
@@ -1948,15 +1946,13 @@ def build_compiled_dag_from_ray_dag(
    enable_asyncio: bool = False,
    asyncio_max_queue_size: Optional[int] = None,
    max_buffered_results: Optional[int] = None,
-    nccl_group: Optional["GPUCommunicator"] = None,
) -> "CompiledDAG":
    compiled_dag = CompiledDAG(
        execution_timeout,
        buffer_size_bytes,
        enable_asyncio,
        asyncio_max_queue_size,
        max_buffered_results,
-        nccl_group,
    )

    def _build_compiled_dag(node):
8 changes: 0 additions & 8 deletions python/ray/dag/dag_node.py
@@ -1,4 +1,3 @@
-from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
@@ -167,7 +166,6 @@ def experimental_compile(
        enable_asyncio: bool = False,
        _asyncio_max_queue_size: Optional[int] = None,
        _max_buffered_results: Optional[int] = None,
-        _nccl_group: Optional["GPUCommunicator"] = None,
    ) -> "ray.dag.CompiledDAG":
        """Compile an accelerated execution path for this DAG.
@@ -188,11 +186,6 @@ def experimental_compile(
                executions is beyond the DAG capacity, the new execution would
                be blocked in the first place; therefore, this limit is only
                enforced when it is smaller than the DAG capacity.
-            _nccl_group: The NCCL group to use for this DAG. If None, the DAG
-                will create a NCCL group internally if NCCL communication is
-                needed. Providing a nccl group here allows reusing, which
-                avoids extra memory overhead when initializing a new NCCL group
-                from aDAG.

        Returns:
            A compiled DAG.
@@ -225,7 +218,6 @@ def experimental_compile(
            enable_asyncio,
            _asyncio_max_queue_size,
            _max_buffered_results,
-            _nccl_group,
        )

    def execute(
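Taken together, these two files move custom NCCL group wiring off the `experimental_compile` signature and onto the tensor type hint. A minimal before/after sketch (identifiers illustrative; the new test below shows the full pattern):

    # Before this commit (parameter now removed):
    # compiled_dag = dag.experimental_compile(_nccl_group=my_group)

    # After this commit: attach the communicator to the NCCL type hint.
    dag = dag.with_type_hint(TorchTensorType(transport=my_group))
    compiled_dag = dag.experimental_compile()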
97 changes: 97 additions & 0 deletions python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -3,6 +3,12 @@
import os
import re
import sys
+from typing import Optional, Tuple
+from ray.experimental.channel.gpu_communicator import (
+    GPUCommunicator,
+    TorchTensorAllocator,
+)
+from ray.experimental.channel.nccl_group import _NcclGroup
import torch
import time
@@ -291,6 +297,97 @@ def test_torch_tensor_nccl_dynamic(ray_start_regular):
    compiled_dag.teardown()


+@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
+def test_torch_tensor_custom_nccl(ray_start_regular):
+    if not USE_GPU:
+        pytest.skip("NCCL tests require GPUs")
+
+    assert (
+        sum(node["Resources"].get("GPU", 0) for node in ray.nodes()) > 1
+    ), "This test requires at least 2 GPUs"
+
+    actor_cls = TorchTensorWorker.options(num_cpus=0, num_gpus=1)
+
+    sender = actor_cls.remote()
+    receiver = actor_cls.remote()
+
+    class TestNcclGroup(GPUCommunicator):
+        """
+        A custom NCCL group for testing. This is a simple wrapper around `_NcclGroup`.
+        """
+
+        def __init__(self, world_size, comm_id, actor_handles):
+            self._world_size = world_size
+            self._comm_id = comm_id
+            self._actor_handles = actor_handles
+            self._inner = None
+
+        def initialize(self, rank: int) -> None:
+            self._inner = _NcclGroup(
+                self._world_size,
+                self._comm_id,
+                rank,
+                self._actor_handles,
+                torch.cuda.current_stream().cuda_stream,
+            )
+
+        def get_rank(self, actor: ray.actor.ActorHandle) -> int:
+            # Implement this without forwarding to `_inner` to allow the method
+            # to be called before initialization.
+            actor_ids = [a._ray_actor_id for a in self._actor_handles]
+            try:
+                rank = actor_ids.index(actor._ray_actor_id)
+            except ValueError:
+                raise ValueError("Actor is not in the NCCL group.")
+            return rank
+
+        def get_world_size(self) -> int:
+            # Implement this without forwarding to `_inner` to allow the method
+            # to be called before initialization.
+            return self._world_size
+
+        def get_self_rank(self) -> Optional[int]:
+            if self._inner is None:
+                return None
+            return self._inner.get_self_rank()
+
+        def send(self, value: "torch.Tensor", peer_rank: int) -> None:
+            return self._inner.send(value, peer_rank)
+
+        def recv(
+            self,
+            shape: Tuple[int],
+            dtype: "torch.dtype",
+            peer_rank: int,
+            allocator: Optional[TorchTensorAllocator] = None,
+        ) -> "torch.Tensor":
+            return self._inner.recv(shape, dtype, peer_rank, allocator=allocator)
+
+        def destroy(self) -> None:
+            return self._inner.destroy()
+
+    from cupy.cuda import nccl
+
+    comm_id = nccl.get_unique_id()
+    nccl_group = TestNcclGroup(2, comm_id, [sender, receiver])
+    with InputNode() as inp:
+        dag = sender.send_with_tuple_args.bind(inp)
+        dag = dag.with_type_hint(TorchTensorType(transport=nccl_group))
+        dag = receiver.recv.bind(dag)
+
+    compiled_dag = dag.experimental_compile()
+    for i in range(3):
+        i += 1
+        shape = (i * 10,)
+        dtype = torch.float16
+        args = (shape, dtype, i)
+        ref = compiled_dag.execute(args)
+        result = ray.get(ref)
+        assert result == (i, shape, dtype)
+
+    compiled_dag.teardown()


@pytest.mark.parametrize("ray_start_regular", [{"num_cpus": 4}], indirect=True)
def test_torch_tensor_nccl_wrong_shape(ray_start_regular):
    if not USE_GPU:
8 changes: 8 additions & 0 deletions python/ray/experimental/channel/common.py
@@ -100,6 +100,14 @@ def requires_nccl(self) -> bool:
        # By default, channels do not require NCCL.
        return False

+    def get_custom_nccl_group(self) -> Optional[GPUCommunicator]:
+        """
+        Return the custom NCCL group if one is specified.
+        """
+        if self._contains_type is not None:
+            return self._contains_type.get_custom_nccl_group()
+        return None

    def set_nccl_group_id(self, group_id: str) -> None:
        raise NotImplementedError

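The base implementation above only recurses into a contained type. The override that actually returns a user-supplied group presumably lives in `TorchTensorType`, another file that did not render here; a hypothetical sketch, assuming the type hint stores the communicator passed as `transport` at construction time:

    # Hypothetical override on TorchTensorType; the attribute name is assumed.
    def get_custom_nccl_group(self) -> Optional[GPUCommunicator]:
        # Returns the GPUCommunicator passed as `transport`, if any.
        return self._custom_nccl_group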
24 changes: 15 additions & 9 deletions python/ray/experimental/channel/gpu_communicator.py
@@ -22,18 +22,18 @@ class GPUCommunicator(ABC):
    between actors in the group.
    """

-    def register(self, group_id: str):
+    @abstractmethod
+    def initialize(self, rank: int) -> None:
        """
-        Register the group in the Ray channel context.
+        Initialize the communicator from the actor.

-        This should be called once remotely on each actor
-        in the group before any other methods can be called,
-        with the same `group_id`.
+        This is called once by aDAG on each actor to initialize the communicator,
+        before any other methods.
+
+        Args:
+            rank: The rank of this actor in the group.
        """
-        from ray.experimental.channel.common import ChannelContext
-
-        ctx = ChannelContext.get_current()
-        ctx.nccl_groups[group_id] = self
+        raise NotImplementedError

    @abstractmethod
    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
@@ -52,6 +52,12 @@ def get_self_rank(self) -> Optional[int]:
        """
        raise NotImplementedError

+    def get_world_size(self) -> int:
+        """
+        Return the number of ranks in the group.
+        """
+        raise NotImplementedError

    @abstractmethod
    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        """
11 changes: 11 additions & 0 deletions python/ray/experimental/channel/nccl_group.py
@@ -67,6 +67,7 @@ def __init__(
            cuda_stream: A raw CUDA stream to dispatch NCCL ops to. If rank is
                specified, then this must be specified too.
        """
+        self._world_size = world_size
        self._rank: Optional[int] = rank
        self.nccl_util: Optional[ModuleType] = None
        self._actor_handles = actor_handles
@@ -105,6 +106,10 @@ def __init__(

        self._closed = False

+    def initialize(self, rank: int) -> None:
+        # No additional initialization is needed.
+        pass

    def _get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
        return self._actor_handles

@@ -128,6 +133,12 @@ def get_self_rank(self) -> Optional[int]:
        """
        return self._rank

+    def get_world_size(self) -> int:
+        """
+        Return the number of ranks in the NCCL communicator.
+        """
+        return self._world_size

    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.
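With `_world_size` stored, `_NcclGroup` itself satisfies the interface it now implements: `initialize` is a no-op because `__init__` already creates the communicator eagerly. A small usage sketch mirroring the test's wrapper (argument order taken from the test above):

    # Sketch: construct the built-in group directly and exercise the two
    # methods added by this commit.
    group = _NcclGroup(
        2,                                        # world_size
        comm_id,                                  # cupy.cuda.nccl.get_unique_id()
        0,                                        # this process's rank
        [sender, receiver],                       # one actor handle per rank
        torch.cuda.current_stream().cuda_stream,  # raw CUDA stream pointer
    )
    group.initialize(0)                 # no-op; __init__ did the work
    assert group.get_world_size() == 2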
(The remaining two of the 8 changed files did not render on this page.)
