[aDAG] Allow custom NCCL group for aDAG
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
ruisearch42 committed Sep 3, 2024
1 parent d4a52ea commit 3f62bff
Showing 9 changed files with 213 additions and 26 deletions.
23 changes: 22 additions & 1 deletion python/ray/dag/compiled_dag_node.py
@@ -8,7 +8,11 @@
import traceback
from typing import NamedTuple

from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.util.annotations import PublicAPI
@@ -34,6 +38,7 @@
)

from ray.experimental.channel.torch_tensor_nccl_channel import (
_set_nccl_group,
_init_nccl_group,
_destroy_nccl_group,
)
@@ -529,6 +534,7 @@ def __init__(
enable_asyncio: bool = False,
asyncio_max_queue_size: Optional[int] = None,
max_buffered_results: Optional[int] = None,
nccl_group: Optional[GPUCommunicator] = None,
):
"""
Args:
@@ -557,6 +563,9 @@ def __init__(
executions is beyond the DAG capacity, the new execution would
be blocked in the first place; therefore, this limit is only
enforced when it is smaller than the DAG capacity.
nccl_group: The NCCL group to use for this DAG. If None, the DAG
will create an NCCL group internally if NCCL communication is
needed.
Returns:
Channel: A wrapper around ray.ObjectRef.
@@ -640,6 +649,7 @@ def __init__(
# Type hints specified by the user for DAG (intermediate) outputs.
self._type_hints = []

self._custom_nccl_group: Optional[GPUCommunicator] = nccl_group
# Uniquely identifies the NCCL communicator that will be used within
# this DAG, if any.
self._nccl_group_id: Optional[str] = None
@@ -916,7 +926,16 @@ def _preprocess(self) -> None:
if None in nccl_actors:
raise ValueError("Driver cannot participate in the NCCL group.")
if nccl_actors and self._nccl_group_id is None:
self._nccl_group_id = _init_nccl_group(nccl_actors)
if self._custom_nccl_group:
self._nccl_group_id = _set_nccl_group(
self._custom_nccl_group, nccl_actors
)
else:
self._nccl_group_id = _init_nccl_group(nccl_actors)
elif self._custom_nccl_group:
raise ValueError(
"The DAG does not use NCCL, but a custom NCCL group was provided."
)

if direct_input:
self._input_num_positional_args = 1
@@ -1932,13 +1951,15 @@ def build_compiled_dag_from_ray_dag(
enable_asyncio: bool = False,
asyncio_max_queue_size: Optional[int] = None,
max_buffered_results: Optional[int] = None,
nccl_group: Optional["GPUCommunicator"] = None,
) -> "CompiledDAG":
compiled_dag = CompiledDAG(
execution_timeout,
buffer_size_bytes,
enable_asyncio,
asyncio_max_queue_size,
max_buffered_results,
nccl_group,
)

def _build_compiled_dag(node):
8 changes: 8 additions & 0 deletions python/ray/dag/dag_node.py
@@ -1,3 +1,4 @@
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.dag.base import DAGNodeBase
from ray.dag.py_obj_scanner import _PyObjScanner
@@ -166,6 +167,7 @@ def experimental_compile(
enable_asyncio: bool = False,
_asyncio_max_queue_size: Optional[int] = None,
_max_buffered_results: Optional[int] = None,
_nccl_group: Optional["GPUCommunicator"] = None,
) -> "ray.dag.CompiledDAG":
"""Compile an accelerated execution path for this DAG.
@@ -186,6 +188,11 @@ def experimental_compile(
executions is beyond the DAG capacity, the new execution would
be blocked in the first place; therefore, this limit is only
enforced when it is smaller than the DAG capacity.
_nccl_group: The NCCL group to use for this DAG. If None, the DAG
will create an NCCL group internally if NCCL communication is
needed. Providing an existing NCCL group here allows it to be
reused, avoiding the extra memory overhead of initializing a new
NCCL group in aDAG.
Returns:
A compiled DAG.
@@ -218,6 +225,7 @@ def experimental_compile(
enable_asyncio,
_asyncio_max_queue_size,
_max_buffered_results,
_nccl_group,
)

def execute(
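For context, below is a minimal sketch of how the new `_nccl_group` argument could be used from application code. Only `experimental_compile(_nccl_group=...)` comes from this commit; the two-actor DAG, the `TorchTensorType(transport="nccl")` hint, and the `custom_group` object are illustrative assumptions.

import ray
from ray.dag import InputNode
from ray.experimental.channel.torch_tensor_type import TorchTensorType

@ray.remote(num_gpus=1)
class Worker:
    def fwd(self, t):
        return t

sender = Worker.remote()
receiver = Worker.remote()

# Any GPUCommunicator implementation already set up by the application
# (see gpu_communicator.py below); construction details are omitted here.
custom_group = ...

with InputNode() as inp:
    dag = sender.fwd.bind(inp)
    # Mark this edge as an NCCL tensor transfer so the group is used.
    dag = dag.with_type_hint(TorchTensorType(transport="nccl"))
    dag = receiver.fwd.bind(dag)

# Reuse the provided communicator instead of initializing a new NCCL group.
compiled_dag = dag.experimental_compile(_nccl_group=custom_group)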
2 changes: 2 additions & 0 deletions python/ray/experimental/channel/__init__.py
@@ -10,6 +10,7 @@
SynchronousWriter,
WriterInterface,
)
from ray.experimental.channel.gpu_communicator import GPUCommunicator
from ray.experimental.channel.intra_process_channel import IntraProcessChannel
from ray.experimental.channel.shared_memory_channel import Channel, CompositeChannel
from ray.experimental.channel.torch_tensor_nccl_channel import TorchTensorNcclChannel
@@ -19,6 +20,7 @@
"AwaitableBackgroundWriter",
"CachedChannel",
"Channel",
"GPUCommunicator",
"ReaderInterface",
"SynchronousReader",
"SynchronousWriter",
4 changes: 2 additions & 2 deletions python/ray/experimental/channel/common.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import ray
from ray.experimental.channel.nccl_group import _NcclGroup
from ray.experimental.channel.gpu_communicator import GPUCommunicator
from ray.experimental.channel.serialization_context import _SerializationContext
from ray.util.annotations import DeveloperAPI, PublicAPI

@@ -112,7 +112,7 @@ class ChannelContext:

def __init__(self):
# Used for the torch.Tensor NCCL transport.
self.nccl_groups: Dict[str, "_NcclGroup"] = {}
self.nccl_groups: Dict[str, "GPUCommunicator"] = {}

@staticmethod
def get_current() -> "ChannelContext":
12 changes: 11 additions & 1 deletion python/ray/experimental/channel/conftest.py
@@ -1,11 +1,13 @@
import asyncio
from collections import defaultdict
from typing import Optional, Tuple
from unittest import mock

import torch

import ray
import ray.experimental.channel as ray_channel
from ray.experimental.channel.gpu_communicator import TorchTensorAllocator


@ray.remote(num_cpus=0)
@@ -74,13 +76,21 @@ def send(self, tensor: torch.Tensor, peer_rank: int):
ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor))
self.num_ops[barrier_key] += 1

def recv(self, buf: torch.Tensor, peer_rank: int):
def recv(
self,
shape: Tuple[int],
dtype: torch.dtype,
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
):
# "Receive" the tensor from the barrier actor.
barrier_key = f"barrier-{peer_rank}-{self.get_self_rank()}"
barrier = ray.get_actor(name=barrier_key)
received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key]))
buf = allocator(shape, dtype)
buf[:] = received_tensor[:]
self.num_ops[barrier_key] += 1
return buf


def start_nccl_mock():
95 changes: 95 additions & 0 deletions python/ray/experimental/channel/gpu_communicator.py
@@ -0,0 +1,95 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional, Tuple

import ray
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import torch


# Signature for a torch.Tensor allocator is:
# (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor.
TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"]


@DeveloperAPI
class GPUCommunicator(ABC):
"""
Communicator for a group of aDAG actors on NVIDIA GPUs.
The aDAG execution leverages this internally to support communication
between actors in the group.
"""

def register(self, group_id: str):
"""
Register the group in the Ray channel context.
This should be called remotely once on each actor in the group,
with the same `group_id`, before any other methods are called.
"""
from ray.experimental.channel.common import ChannelContext

ctx = ChannelContext.get_current()
ctx.nccl_groups[group_id] = self

@abstractmethod
def get_rank(self, actor: ray.actor.ActorHandle) -> int:
"""
Return the given actor's rank in the group.
Args:
actor: The actor handle to look up.
"""
raise NotImplementedError

@abstractmethod
def get_self_rank(self) -> Optional[int]:
"""
Return this actor's rank.
"""
raise NotImplementedError

@abstractmethod
def send(self, value: "torch.Tensor", peer_rank: int) -> None:
"""
Send a torch.Tensor to a peer.
This returns when the send kernel has been queued, but the kernel may
not have completed. Therefore, the caller should ensure that there are
no concurrent writes to the sent `value` until the send has finished.
That is, either all writes should be submitted on the current stream
(self._cuda_stream) or, if on a different stream, that stream should
synchronize with the current stream.
Args:
value: The torch.Tensor to send. It should already be on this
actor's default device.
peer_rank: The rank of the actor to send to.
"""
raise NotImplementedError

@abstractmethod
def recv(
self,
shape: Tuple[int],
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
"""
Receive a torch.Tensor from a peer and synchronize.
After this call returns, the receive buffer is safe to read from
any stream. A RayChannelError will be raised if an error occurred (e.g.,
the remote actor died), in which case the buffer is not safe to read.
Args:
shape: The shape of the tensor to receive.
dtype: The dtype of the tensor to receive.
peer_rank: The rank of the actor to receive from.
allocator: A function to allocate the tensor to receive into.
"""
raise NotImplementedError
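To make the interface concrete, here is a minimal sketch of a custom implementation. It assumes torch.distributed has already been initialized across the participating actors; the class name and the `actor_ranks` argument are illustrative, while the method signatures follow the GPUCommunicator ABC above.

from typing import Dict, Optional, Tuple

import torch
import torch.distributed as dist

import ray
from ray.experimental.channel.gpu_communicator import (
    GPUCommunicator,
    TorchTensorAllocator,
)


class TorchDistCommunicator(GPUCommunicator):
    """Sketch of a GPUCommunicator backed by an existing torch.distributed group."""

    def __init__(self, actor_ranks: Dict["ray.actor.ActorHandle", int]):
        # Map each actor handle to its rank in the torch.distributed group.
        self._actor_ranks = actor_ranks

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        return self._actor_ranks[actor]

    def get_self_rank(self) -> Optional[int]:
        # None outside the group (e.g., on the driver).
        return dist.get_rank() if dist.is_initialized() else None

    def send(self, value: torch.Tensor, peer_rank: int) -> None:
        dist.send(value, dst=peer_rank)

    def recv(
        self,
        shape: Tuple[int],
        dtype: torch.dtype,
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ) -> torch.Tensor:
        assert allocator is not None, "an allocator is required to create the buffer"
        buf = allocator(shape, dtype)
        dist.recv(buf, src=peer_rank)
        return buf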
26 changes: 20 additions & 6 deletions python/ray/experimental/channel/nccl_group.py
@@ -1,9 +1,13 @@
import logging
from types import ModuleType
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple

import ray
from ray.exceptions import RayChannelError
from ray.experimental.channel.gpu_communicator import (
GPUCommunicator,
TorchTensorAllocator,
)

if TYPE_CHECKING:
import cupy as cp
@@ -16,9 +20,10 @@
logger = logging.getLogger(__name__)


class _NcclGroup:
class _NcclGroup(GPUCommunicator):
"""
Represents an actor's NCCL communicator.
Represents an actor's NCCL communicator. This is the default NCCL communicator
to be used in aDAG if a custom communicator is not provided.
This class is not thread-safe.
"""
@@ -123,7 +128,7 @@ def get_self_rank(self) -> Optional[int]:
"""
return self._rank

def send(self, value: "torch.Tensor", peer_rank: int):
def send(self, value: "torch.Tensor", peer_rank: int) -> None:
"""
Send a torch.Tensor to a peer.
@@ -151,7 +156,13 @@ def send(self, value: "torch.Tensor", peer_rank: int):
self._cuda_stream.ptr,
)

def recv(self, buf: "torch.Tensor", peer_rank: int):
def recv(
self,
shape: Tuple[int],
dtype: "torch.dtype",
peer_rank: int,
allocator: Optional[TorchTensorAllocator] = None,
) -> "torch.Tensor":
"""
Receive a torch.Tensor from a peer and synchronize the current stream.
@@ -165,6 +176,8 @@ def recv(self, buf: "torch.Tensor", peer_rank: int):
"""
if self._closed:
raise RayChannelError("NCCL group has been destroyed.")
assert allocator is not None, "NCCL group requires a tensor allocator"
buf = allocator(shape, dtype)
self._comm.recv(
self.nccl_util.get_tensor_ptr(buf),
buf.numel(),
@@ -180,8 +193,9 @@ def recv(self, buf: "torch.Tensor", peer_rank: int):
self._cuda_stream.synchronize()
if self._closed:
raise RayChannelError("NCCL group has been destroyed.")
return buf

def destroy(self):
def destroy(self) -> None:
"""
Destroy the NCCL group.
"""
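As a usage note on the new recv() contract: the allocator is any callable matching TorchTensorAllocator, so the channel layer controls where the receive buffer is allocated. A minimal sketch (the function name, shapes, and device choice are illustrative only):

import torch

def cuda_allocator(shape, dtype):
    # Allocate the receive buffer on this actor's default CUDA device.
    return torch.empty(shape, dtype=dtype, device="cuda")

# e.g.:
# tensor = nccl_group.recv((4, 4), torch.float16, peer_rank=0, allocator=cuda_allocator)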