Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
  • Loading branch information
ruisearch42 committed Aug 19, 2024
1 parent 548344c commit a212a01
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions python/ray/experimental/channel/actor_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Optional, Tuple

from ray.experimental.channel.common import ChannelContext
import ray
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
import torch


# Type alias for a torch.Tensor allocator callback. Its signature is:
# (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor.
# The torch types are written as string forward references because torch is
# only imported under TYPE_CHECKING in this module, not at runtime.
TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"]


@DeveloperAPI
class ActorGroup(ABC):
    """
    Communicator for a group of aDAG actors.

    The aDAG execution leverages this internally to support communication
    between actors in the group. Subclasses supply the actual transport by
    implementing the abstract methods; `register` is provided here.
    """

    def register(self, group_id: str) -> None:
        """
        Register the group in the Ray channel context.

        This should be called once remotely on each actor
        in the group before any other methods can be called,
        with the same `group_id`.

        Args:
            group_id: The key under which this group is stored in the
                current channel context.
        """
        ctx = ChannelContext.get_current()
        # The registry attribute is named `nccl_groups`, but it stores
        # whichever ActorGroup implementation is registered here.
        ctx.nccl_groups[group_id] = self

    @abstractmethod
    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank in the group.

        Args:
            actor: The actor handle to look up.

        Returns:
            The rank of `actor` within this group.
        """
        raise NotImplementedError

    @abstractmethod
    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.

        Returns:
            This actor's rank. The annotation allows None; presumably that
            covers a caller that is not a member of the group -- confirm
            against the concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `value` until the send has finished.
        That is, either all writes should be submitted on the current stream
        (self._cuda_stream) or, if on a different stream, that stream should
        synchronize with the current stream.

        NOTE: `self._cuda_stream` is not defined by this base class;
        presumably it names the stream attribute of a concrete
        implementation -- confirm when subclassing.

        Args:
            value: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        raise NotImplementedError

    @abstractmethod
    def recv(
        self,
        shape: Tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize.

        After this call returns, the receive buffer is safe to read from
        any stream. A RayChannelError will be raised if an error occurred
        (e.g., remote actor died), and the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive.
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: A function to allocate the tensor to receive into.
                When None, how the buffer is allocated is left to the
                implementation.

        Returns:
            The received torch.Tensor.

        Raises:
            RayChannelError: If an error occurred during the receive
                (e.g., the remote actor died).
        """
        raise NotImplementedError

0 comments on commit a212a01

Please sign in to comment.