Skip to content

Commit

Permalink
Add functionality for user to register close callback (#795)
Browse files Browse the repository at this point in the history
* Add functionality for user to register close callback

When a user callback is registered, the Endpoint error callback or its
finalizer will run the callback to inform the user's application that
the Endpoint is terminating and will not accept sending new messages,
although receiving may still be possible as UCP may still have some
incoming messages in transit.

* Add core API close callback test

* Add async API close callback test

* Fix core test_close_callback

Since ep.close() calls worker.progress(), it can't be called within the
listener callback; we therefore have to progress outside the callback
until the close callback is called. We must also check that the processes'
exitcodes are exactly `0`: if a process times out, its exitcode is `None`,
which can't be told apart from `0` with just `not process.exitcode`.

* Make UCXEndpointCloseCallback a regular Python class

* Deregister UCXEndpointCloseCallback callback before calling it

* Avoid partial function in test_close_callback
  • Loading branch information
pentschev authored Oct 28, 2021
1 parent e28d770 commit addee5e
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

import ucp


@pytest.mark.asyncio
@pytest.mark.parametrize("server_close_callback", [True, False])
async def test_close_callback(server_close_callback):
    """Check that a registered close callback runs once the peer closes.

    Parameters
    ----------
    server_close_callback: bool
        If True, the server endpoint registers the callback and the client
        closes its endpoint; if False, the roles are swapped.
    """
    # Imported locally because the module's import block only provides
    # ``pytest`` and ``ucp``.
    import asyncio

    endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0)
    # One-element list used as a mutable cell shared with the closures below.
    closed = [False]

    def _close_callback():
        closed[0] = True

    async def server_node(ep):
        if server_close_callback is True:
            ep.set_close_callback(_close_callback)
        msg = bytearray(10)
        await ep.recv(msg)
        if server_close_callback is False:
            await ep.close()

    async def client_node(port):
        ep = await ucp.create_endpoint(
            ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling
        )
        if server_close_callback is False:
            ep.set_close_callback(_close_callback)
        await ep.send(bytearray(b"0" * 10))
        if server_close_callback is True:
            await ep.close()

    listener = ucp.create_listener(
        server_node, endpoint_error_handling=endpoint_error_handling
    )
    await client_node(listener.port)

    # The callback is fired from the endpoint's error callback or finalizer,
    # which is not guaranteed to have run by the time ``client_node`` returns.
    # Yield to the event loop for a bounded time (at most ~10 s) instead of
    # asserting immediately.
    for _ in range(200):
        if closed[0]:
            break
        await asyncio.sleep(0.05)
    assert closed[0] is True
85 changes: 85 additions & 0 deletions ucp/_libs/tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import functools
import multiprocessing as mp

import pytest

from ucp._libs import ucx_api

# Rebind ``mp`` to a "spawn" multiprocessing context, so the Queue and Process
# objects created through it below use the "spawn" start method.
# NOTE(review): presumably chosen so the children do not inherit state from
# the parent process via fork -- confirm.
mp = mp.get_context("spawn")


def _close_callback(closed):
closed[0] = True


def _server(queue, endpoint_error_handling, server_close_callback):
    """Run the listening side of the close callback test.

    Creates a listener, publishes its port through ``queue`` and then
    progresses the worker until the awaited event happens.

    Notice, since it is illegal to call progress() in call-back functions,
    the callbacks only flip flags and all progressing is done in the loop at
    the end of this function.

    Parameters
    ----------
    queue: multiprocessing queue
        Receives the listener's port so the parent can hand it to the client.
    endpoint_error_handling: bool
        Forwarded to ``UCXEndpoint.create_from_conn_request``.
    server_close_callback: bool
        If True, a close callback is registered on the accepted endpoint and
        the worker is progressed until it runs. Otherwise the worker is only
        progressed until the connection request has been handled.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # One-element lists act as mutable flags shared with the callbacks.
    listener_finished = [False]
    closed = [False]

    # A reference to listener's endpoint is stored to prevent it from going
    # out of scope too early. Kept in a local list rather than a module-level
    # global so nothing leaks out of this function.
    live_endpoints = []

    def _listener_handler(conn_request):
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=endpoint_error_handling,
        )
        live_endpoints.append(ep)
        if server_close_callback is True:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    # Pick the flag to wait on, then progress until a callback sets it.
    done = closed if server_close_callback is True else listener_finished
    while done[0] is False:
        worker.progress()


def _client(port, endpoint_error_handling, server_close_callback):
    """Run the connecting side of the close callback test.

    Connects to ``localhost:port``. When the server owns the close callback
    (``server_close_callback is True``) the endpoint is closed and the worker
    progressed once. Otherwise a close callback is registered on this side
    and the worker is progressed until that callback has run.
    """
    context = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(context)
    endpoint = ucx_api.UCXEndpoint.create(
        worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
    )

    if server_close_callback is not True:
        # This side waits to be notified that the endpoint is closing.
        notified = [False]
        endpoint.set_close_callback(functools.partial(_close_callback, notified))
        while notified[0] is False:
            worker.progress()
        return

    # The server is the one waiting: close from this side.
    endpoint.close()
    worker.progress()


@pytest.mark.parametrize("server_close_callback", [True, False])
def test_close_callback(server_close_callback):
    """Run ``_server`` and ``_client`` in separate processes and check both exit cleanly.

    Parameters
    ----------
    server_close_callback: bool
        If True, the server registers the close callback and the client
        closes its endpoint; if False, the client registers the callback.
    """
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)

    queue = mp.Queue()
    server = mp.Process(
        target=_server, args=(queue, endpoint_error_handling, server_close_callback),
    )
    server.start()
    port = queue.get()
    client = mp.Process(
        target=_client, args=(port, endpoint_error_handling, server_close_callback),
    )
    client.start()
    client.join(timeout=10)
    server.join(timeout=10)

    # A child that is still alive after the timeout would otherwise keep
    # spinning in worker.progress() after the test has finished. Stop it here;
    # a terminated child has a non-zero exitcode, so the assertions below
    # still fail for it.
    for proc in (client, server):
        if proc.is_alive():
            proc.terminate()
            proc.join()

    # Compare against 0 explicitly: ``not exitcode`` would also accept None.
    assert client.exitcode == 0
    assert server.exitcode == 0
25 changes: 25 additions & 0 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError
logger = logging.getLogger("ucx")


class UCXEndpointCloseCallback:
    """Holder for an optional user callback that is run at most once.

    The same instance is handed to both the endpoint's error callback and
    its finalizer; whichever calls `run()` first executes the user callback,
    and later calls do nothing.
    """

    def __init__(self):
        # No callback registered until `set()` is called.
        self._cb_func = None

    def run(self):
        """Run the registered callback, if any, and deregister it."""
        if self._cb_func is not None:
            # Deregister callback to prevent calling from the endpoint error
            # callback and again from the finalizer. This is done before the
            # call so the callback is not run again even if it raises.
            cb_func, self._cb_func = self._cb_func, None
            cb_func()

    def set(self, cb_func):
        """Register `cb_func`, replacing any previously registered callback.

        Parameters
        ----------
        cb_func: callable or None
            Function taking no arguments. `None` clears the registration.

        Raises
        ------
        TypeError
            If `cb_func` is neither callable nor `None`. Rejecting it here
            reports the mistake to the caller at registration time instead
            of failing later inside the error callback or the finalizer.
        """
        if cb_func is not None and not callable(cb_func):
            raise TypeError(
                "close callback must be callable or None, got %s"
                % type(cb_func).__name__
            )
        self._cb_func = cb_func


cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
cdef UCXEndpoint ucx_ep = <UCXEndpoint> arg
assert ucx_ep.worker.initialized
Expand All @@ -30,6 +45,7 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
hex(int(<uintptr_t>ep)), status, status_str
)
)
ucx_ep._endpoint_close_callback.run()
logger.debug(msg)


Expand Down Expand Up @@ -59,6 +75,7 @@ def _ucx_endpoint_finalizer(
bint endpoint_error_handling,
UCXWorker worker,
set inflight_msgs,
object endpoint_close_callback,
):
assert worker.initialized
cdef ucp_ep_h handle = <ucp_ep_h>handle_as_int
Expand Down Expand Up @@ -120,6 +137,8 @@ def _ucx_endpoint_finalizer(
msg = ucs_status_string(UCS_PTR_STATUS(status)).decode("utf-8")
raise UCXError("Error while closing endpoint: %s" % msg)

endpoint_close_callback.run()


cdef class UCXEndpoint(UCXObject):
"""Python representation of `ucp_ep_h`"""
Expand All @@ -128,6 +147,7 @@ cdef class UCXEndpoint(UCXObject):
uintptr_t _status
bint _endpoint_error_handling
set _inflight_msgs
object _endpoint_close_callback

cdef readonly:
UCXWorker worker
Expand All @@ -143,6 +163,7 @@ cdef class UCXEndpoint(UCXObject):
assert worker.initialized
self.worker = worker
self._inflight_msgs = set()
self._endpoint_close_callback = UCXEndpointCloseCallback()

cdef ucp_err_handler_cb_t err_cb
cdef uintptr_t ep_status
Expand Down Expand Up @@ -172,6 +193,7 @@ cdef class UCXEndpoint(UCXObject):
endpoint_error_handling,
worker,
self._inflight_msgs,
self._endpoint_close_callback,
)
worker.add_child(self)

Expand Down Expand Up @@ -305,3 +327,6 @@ cdef class UCXEndpoint(UCXObject):

def unpack_rkey(self, rkey):
return UCXRkey(self, rkey)

def set_close_callback(self, cb_func):
    """Register `cb_func` as this endpoint's close callback.

    The function is stored in the endpoint's close callback holder, which
    is run by the endpoint's error callback and by its finalizer.
    """
    holder = self._endpoint_close_callback
    holder.set(cb_func)
22 changes: 22 additions & 0 deletions ucp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,28 @@ async def flush(self):
logger.debug("[Flush] ep: %s" % (hex(self.uid)))
return await comm.flush_ep(self._ep)

def set_close_callback(self, callback_func):
    """Register a user callback function to be called on Endpoint's closing.

    The callback is run when the Endpoint's error callback is called, or
    during the Endpoint's finalizer if the error callback is never called.

    Once the callback is called, it's not possible to send any more
    messages. However, receiving messages may still be possible, as UCP may
    still have incoming messages in transit.

    Parameters
    ----------
    callback_func: callable
        The callback function to be called when the Endpoint's error
        callback is called, otherwise called on its finalizer.

    Examples
    --------
    >>> ep.set_close_callback(lambda: print("Executing close callback"))
    """
    underlying_ep = self._ep
    underlying_ep.set_close_callback(callback_func)


# The following functions initialize and use a single ApplicationContext instance

Expand Down

0 comments on commit addee5e

Please sign in to comment.