Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register UCX close callback #5474

Merged
merged 4 commits into from
Nov 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

.. _UCX: https://github.com/openucx/ucx
"""
import functools
import logging
import os
import struct
Expand Down Expand Up @@ -166,6 +167,18 @@ def device_array(n):
ucx_create_listener = EndpointReuse.create_listener


def _close_comm(ref):
"""Callback to close Dask Comm when UCX Endpoint closes or errors

Parameters
----------
ref: weak reference to a Dask UCX comm
"""
comm = ref()
if comm is not None:
comm._closed = True


class UCX(Comm):
"""Comm object using UCP.

Expand Down Expand Up @@ -210,6 +223,17 @@ def __init__(self, ep, local_addr: str, peer_addr: str, deserialize=True):
self._peer_addr = peer_addr
self.deserialize = deserialize
self.comm_flag = None

# When the UCX endpoint closes or errors the registered callback
# is called.
if hasattr(self._ep, "set_close_callback"):
ref = weakref.ref(self)
self._ep.set_close_callback(functools.partial(_close_comm, ref))
self._closed = False
self._has_close_callback = True
else:
self._has_close_callback = False

logger.debug("UCX.__init__ %s", self)

@property
Expand Down Expand Up @@ -341,6 +365,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
return msg

async def close(self):
self._closed = True
if self._ep is not None:
try:
await self.ep.send(struct.pack("?Q", True, 0))
Expand All @@ -357,6 +382,7 @@ async def close(self):
self._ep = None

def abort(self):
self._closed = True
if self._ep is not None:
self._ep.abort()
self._ep = None
Expand All @@ -369,7 +395,13 @@ def ep(self):
raise CommClosedError("UCX Endpoint is closed")

def closed(self):
return self._ep is None
if self._has_close_callback is True:
# The self._closed flag is separate from the endpoint's lifetime, even when
# the endpoint has closed or errored, there may be messages on its buffer
# still to be received, even though sending is not possible anymore.
return self._closed
else:
return self._ep is None


class UCXConnector(Connector):
Expand Down
7 changes: 7 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,13 @@ async def close(

self.stop_services()

# Give some time for a UCX scheduler to complete closing endpoints
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed comments here

# before closing self.batched_stream, otherwise the local endpoint
# may be closed too early and errors be raised on the scheduler when
# trying to send closing message.
if self._protocol == "ucx":
await asyncio.sleep(0.2)

if (
self.batched_stream
and self.batched_stream.comm
Expand Down