Skip to content

Commit

Permalink
[Ray] Ray client channel get recv when first complied (#2740)
Browse files Browse the repository at this point in the history
  • Loading branch information
Catch-Bull authored Apr 8, 2022
1 parent 930b3f0 commit 79dd902
Showing 1 changed file with 33 additions and 22 deletions.
55 changes: 33 additions & 22 deletions mars/oscar/backends/ray/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class RayChannelBase(Channel, ABC):
Channel for communications between ray processes.
"""

__slots__ = "_channel_index", "_channel_id", "_in_queue", "_closed"
__slots__ = "_channel_index", "_channel_id", "_closed"

name = "ray"
_channel_index_gen = itertools.count()
Expand All @@ -142,7 +142,6 @@ def __init__(
self._channel_id = channel_id or ChannelID(
local_address, _gen_client_id(), self._channel_index, dest_address
)
self._in_queue = asyncio.Queue()
self._closed = asyncio.Event()

@property
Expand All @@ -169,7 +168,7 @@ class RayClientChannel(RayChannelBase):
A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
"""

__slots__ = ("_peer_actor",)
__slots__ = "_peer_actor", "_done", "_todo"

def __init__(
self,
Expand All @@ -181,36 +180,47 @@ def __init__(
super().__init__(None, dest_address, channel_index, channel_id, compression)
# ray actor should be created with the address as the name.
self._peer_actor: "ray.actor.ActorHandle" = ray.get_actor(dest_address)
self._done = asyncio.Queue()
self._todo = set()

def _submit_task(self, message: Any, object_ref: "ray.ObjectRef"):
    """Schedule retrieval of a remote reply and queue it once it finishes.

    The awaiting of ``object_ref`` is wrapped in a future that is tracked in
    ``self._todo`` while pending. When the future finishes (with a result or
    an exception) it is removed from ``self._todo`` and pushed onto
    ``self._done``, so consumers of ``self._done`` see replies in completion
    order.

    Parameters
    ----------
    message
        The message that was sent; only used for the timeout debug report.
    object_ref
        The awaitable reference to the peer's reply.
    """

    async def _await_reply(sent: Any, reply_ref: "ray.ObjectRef"):
        # `%.500s` truncates the reported message so the debug output stays short.
        with debug_async_timeout(
            "ray_object_retrieval_timeout", "Client sent message is %.500s", sent
        ):
            reply = await reply_ref
        if not isinstance(reply, RayChannelException):
            return reply.message
        # The peer reported a failure: re-raise it with its recorded traceback.
        raise reply.exc_value.with_traceback(reply.exc_traceback)

    def _move_to_done(finished):
        # Hand the finished future over from the pending set to the done queue.
        self._todo.remove(finished)
        self._done.put_nowait(finished)

    pending = asyncio.ensure_future(_await_reply(message, object_ref))
    pending.add_done_callback(_move_to_done)
    self._todo.add(pending)

@implements(Channel.send)
async def send(self, message: Any):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed("Channel already closed, cannot send message")
# Put ray object ref to queue
self._in_queue.put_nowait(
(
message,
self._peer_actor.__on_ray_recv__.remote(
self.channel_id, _ArgWrapper(message)
),
)
# Put ray object ref to todo queue
task = self._peer_actor.__on_ray_recv__.remote(
self.channel_id, _ArgWrapper(message)
)
self._submit_task(message, task)
await asyncio.sleep(0)

@implements(Channel.recv)
async def recv(self):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed("Channel already closed, cannot recv message")
try:
# Wait on ray object ref
message, object_ref = await self._in_queue.get()
# use `%.500s` to avoid printing overly long messages
with debug_async_timeout(
"ray_object_retrieval_timeout", "Client sent message is %.500s", message
):
result = await object_ref
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return result.message
# Wait first done.
future = await self._done.get()
return future.result()
except ray.exceptions.RayActorError:
if not self._closed.is_set():
# raise an EOFError as the SocketChannel does
Expand All @@ -228,7 +238,7 @@ class RayServerChannel(RayChannelBase):
message's reply.
"""

__slots__ = "_out_queue", "_msg_recv_counter", "_msg_sent_counter"
__slots__ = "_in_queue", "_out_queue", "_msg_recv_counter", "_msg_sent_counter"

def __init__(
self,
Expand All @@ -238,6 +248,7 @@ def __init__(
compression=None,
):
super().__init__(local_address, None, channel_index, channel_id, compression)
self._in_queue = asyncio.Queue()
self._out_queue = asyncio.Queue()
self._msg_recv_counter = 0
self._msg_sent_counter = 0
Expand Down

0 comments on commit 79dd902

Please sign in to comment.