Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ray] Ray client channel get recv when first complied #2740

Merged
merged 8 commits into from
Apr 8, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions mars/oscar/backends/ray/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class RayChannelBase(Channel, ABC):
Channel for communications between ray processes.
"""

__slots__ = "_channel_index", "_channel_id", "_in_queue", "_closed"
__slots__ = "_channel_index", "_channel_id", "_closed"

name = "ray"
_channel_index_gen = itertools.count()
Expand All @@ -142,7 +142,6 @@ def __init__(
self._channel_id = channel_id or ChannelID(
local_address, _gen_client_id(), self._channel_index, dest_address
)
self._in_queue = asyncio.Queue()
self._closed = asyncio.Event()

@property
Expand All @@ -169,7 +168,7 @@ class RayClientChannel(RayChannelBase):
A channel from ray driver/actor to ray actor. Use ray call reply for client channel recv.
"""

__slots__ = ("_peer_actor",)
__slots__ = "_peer_actor", "_done", "_todo"

def __init__(
self,
Expand All @@ -181,36 +180,47 @@ def __init__(
super().__init__(None, dest_address, channel_index, channel_id, compression)
# ray actor should be created with the address as the name.
self._peer_actor: "ray.actor.ActorHandle" = ray.get_actor(dest_address)
self._done = asyncio.Queue()
self._todo = set()

def _submit_task(self, message: Any, object_ref: "ray.ObjectRef"):
async def handle_task(message: Any, object_ref: "ray.ObjectRef"):
# use `%.500` to avoid print too long messages
with debug_async_timeout(
"ray_object_retrieval_timeout", "Client sent message is %.500s", message
):
result = await object_ref
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return result.message

def _on_completion(future):
self._todo.remove(future)
self._done.put_nowait(future)

future = asyncio.ensure_future(handle_task(message, object_ref))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ensure_future will create an asyncio.Task, which may have some cost. Wonder whether we should use direct ray call for ray channel, like we did in #2690

future.add_done_callback(_on_completion)
self._todo.add(future)

@implements(Channel.send)
async def send(self, message: Any):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed("Channel already closed, cannot send message")
# Put ray object ref to queue
self._in_queue.put_nowait(
(
message,
self._peer_actor.__on_ray_recv__.remote(
self.channel_id, _ArgWrapper(message)
),
)
# Put ray object ref to todo queue
task = self._peer_actor.__on_ray_recv__.remote(
self.channel_id, _ArgWrapper(message)
)
self._submit_task(message, task)
await asyncio.sleep(0)

@implements(Channel.recv)
async def recv(self):
if self._closed.is_set(): # pragma: no cover
raise ChannelClosed("Channel already closed, cannot recv message")
try:
# Wait on ray object ref
message, object_ref = await self._in_queue.get()
# use `%.500` to avoid print too long messages
with debug_async_timeout(
"ray_object_retrieval_timeout", "Client sent message is %.500s", message
):
result = await object_ref
if isinstance(result, RayChannelException):
raise result.exc_value.with_traceback(result.exc_traceback)
return result.message
# Wait first done.
future = await self._done.get()
return future.result()
except ray.exceptions.RayActorError:
if not self._closed.is_set():
# raise a EOFError as the SocketChannel does
Expand All @@ -228,7 +238,7 @@ class RayServerChannel(RayChannelBase):
message's reply.
"""

__slots__ = "_out_queue", "_msg_recv_counter", "_msg_sent_counter"
__slots__ = "_in_queue", "_out_queue", "_msg_recv_counter", "_msg_sent_counter"

def __init__(
self,
Expand All @@ -238,6 +248,7 @@ def __init__(
compression=None,
):
super().__init__(local_address, None, channel_index, channel_id, compression)
self._in_queue = asyncio.Queue()
self._out_queue = asyncio.Queue()
self._msg_recv_counter = 0
self._msg_sent_counter = 0
Expand Down