Skip to content

Commit

Permalink
Register UCX close callback (#5474)
Browse files Browse the repository at this point in the history
Register UCX close callback
  • Loading branch information
pentschev authored Nov 1, 2021
1 parent 11c41b5 commit 8cc4284
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
34 changes: 33 additions & 1 deletion distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
.. _UCX: https://github.com/openucx/ucx
"""
import functools
import logging
import os
import struct
Expand Down Expand Up @@ -166,6 +167,18 @@ def device_array(n):
ucx_create_listener = EndpointReuse.create_listener


def _close_comm(ref):
    """Flag a Dask comm as closed once its UCX Endpoint closes or errors.

    Intended to be used as a callback: it is handed a weak reference and
    marks the referenced comm as closed if that comm is still alive.

    Parameters
    ----------
    ref: weak reference to a Dask UCX comm
    """
    target = ref()
    if target is None:
        # The referent is gone, so there is nothing left to flag.
        return
    target._closed = True


class UCX(Comm):
"""Comm object using UCP.
Expand Down Expand Up @@ -210,6 +223,17 @@ def __init__(self, ep, local_addr: str, peer_addr: str, deserialize=True):
self._peer_addr = peer_addr
self.deserialize = deserialize
self.comm_flag = None

# When the UCX endpoint closes or errors, the registered callback
# is called.
if hasattr(self._ep, "set_close_callback"):
ref = weakref.ref(self)
self._ep.set_close_callback(functools.partial(_close_comm, ref))
self._closed = False
self._has_close_callback = True
else:
self._has_close_callback = False

logger.debug("UCX.__init__ %s", self)

@property
Expand Down Expand Up @@ -341,6 +365,7 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
return msg

async def close(self):
self._closed = True
if self._ep is not None:
try:
await self.ep.send(struct.pack("?Q", True, 0))
Expand All @@ -357,6 +382,7 @@ async def close(self):
self._ep = None

def abort(self):
    """Mark this comm as closed and abort its endpoint, if one is still held.

    The ``_closed`` flag is set unconditionally; the endpoint is aborted
    and the reference to it dropped only when ``self._ep`` is not ``None``.
    """
    self._closed = True
    endpoint = self._ep
    if endpoint is None:
        # No endpoint is held, so there is nothing to abort.
        return
    endpoint.abort()
    self._ep = None
Expand All @@ -369,7 +395,13 @@ def ep(self):
raise CommClosedError("UCX Endpoint is closed")

def closed(self):
    """Report whether this comm is closed.

    Returns
    -------
    bool
        ``self._closed`` when a close callback was registered on the
        endpoint (``self._has_close_callback`` is true); otherwise whether
        ``self._ep`` is ``None``.
    """
    if self._has_close_callback:
        # The self._closed flag is separate from the endpoint's lifetime: even
        # when the endpoint has closed or errored, there may be messages on its
        # buffer still to be received, even though sending is no longer possible.
        return self._closed
    return self._ep is None


class UCXConnector(Connector):
Expand Down
7 changes: 7 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,13 @@ async def close(

self.stop_services()

# Give some time for a UCX scheduler to complete closing endpoints
# before closing self.batched_stream; otherwise the local endpoint
# may be closed too early and errors may be raised on the scheduler
# when it tries to send the closing message.
if self._protocol == "ucx":
await asyncio.sleep(0.2)

if (
self.batched_stream
and self.batched_stream.comm
Expand Down

0 comments on commit 8cc4284

Please sign in to comment.