From addee5e50f411a1d3e5e4a8c3f2301d63f9d6210 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 28 Oct 2021 13:22:12 +0200 Subject: [PATCH] Add functionality for user to register close callback (#795) * Add functionality for user to register close callback When a user callback is registered, the Endpoint error callback or its finalizer will run the callback to inform the user's application that the Endpoint is terminating and will not accept sending new messages, although receiving may still be possible as UCP may still have some incoming messages in transit. * Add core API close callback test * Add async API close callback test * Fix core test_close_callback Since ep.close() calls worker.progress(), it can't be called within the listener callback; therefore, we have to progress outside until the callback is called. We must check that the processes' exitcodes are `0`: if they time out, the exitcode is `None`, which can't be checked correctly with just `not process.exitcode`. * Make UCXEndpointCloseCallback a regular Python class * Deregister UCXEndpointCloseCallback callback before calling it * Avoid partial function in test_close_callback --- tests/test_endpoint.py | 37 ++++++++++++++ ucp/_libs/tests/test_endpoint.py | 85 ++++++++++++++++++++++++++++++++ ucp/_libs/ucx_endpoint.pyx | 25 ++++++++++ ucp/core.py | 22 +++++++++ 4 files changed, 169 insertions(+) create mode 100644 tests/test_endpoint.py create mode 100644 ucp/_libs/tests/test_endpoint.py diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py new file mode 100644 index 000000000..332d15da1 --- /dev/null +++ b/tests/test_endpoint.py @@ -0,0 +1,37 @@ +import pytest + +import ucp + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server_close_callback", [True, False]) +async def test_close_callback(server_close_callback): + endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0) + closed = [False] + + def _close_callback(): + closed[0] = True + + async def server_node(ep): + if 
server_close_callback is True: + ep.set_close_callback(_close_callback) + msg = bytearray(10) + await ep.recv(msg) + if server_close_callback is False: + await ep.close() + + async def client_node(port): + ep = await ucp.create_endpoint( + ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling + ) + if server_close_callback is False: + ep.set_close_callback(_close_callback) + await ep.send(bytearray(b"0" * 10)) + if server_close_callback is True: + await ep.close() + + listener = ucp.create_listener( + server_node, endpoint_error_handling=endpoint_error_handling + ) + await client_node(listener.port) + assert closed[0] is True diff --git a/ucp/_libs/tests/test_endpoint.py b/ucp/_libs/tests/test_endpoint.py new file mode 100644 index 000000000..53d2b6d21 --- /dev/null +++ b/ucp/_libs/tests/test_endpoint.py @@ -0,0 +1,85 @@ +import functools +import multiprocessing as mp + +import pytest + +from ucp._libs import ucx_api + +mp = mp.get_context("spawn") + + +def _close_callback(closed): + closed[0] = True + + +def _server(queue, endpoint_error_handling, server_close_callback): + """Server that send received message back to the client + + Notice, since it is illegal to call progress() in call-back functions, + we use a "chain" of call-back functions. + """ + ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) + worker = ucx_api.UCXWorker(ctx) + + listener_finished = [False] + closed = [False] + + # A reference to listener's endpoint is stored to prevent it from going + # out of scope too early. 
+ # ep = None + + def _listener_handler(conn_request): + global ep + ep = ucx_api.UCXEndpoint.create_from_conn_request( + worker, conn_request, endpoint_error_handling=endpoint_error_handling, + ) + if server_close_callback is True: + ep.set_close_callback(functools.partial(_close_callback, closed)) + listener_finished[0] = True + + listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) + queue.put(listener.port) + + if server_close_callback is True: + while closed[0] is False: + worker.progress() + assert closed[0] is True + else: + while listener_finished[0] is False: + worker.progress() + + +def _client(port, endpoint_error_handling, server_close_callback): + ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) + worker = ucx_api.UCXWorker(ctx) + ep = ucx_api.UCXEndpoint.create( + worker, "localhost", port, endpoint_error_handling=endpoint_error_handling, + ) + if server_close_callback is True: + ep.close() + worker.progress() + else: + closed = [False] + ep.set_close_callback(functools.partial(_close_callback, closed)) + while closed[0] is False: + worker.progress() + + +@pytest.mark.parametrize("server_close_callback", [True, False]) +def test_close_callback(server_close_callback): + endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0) + + queue = mp.Queue() + server = mp.Process( + target=_server, args=(queue, endpoint_error_handling, server_close_callback), + ) + server.start() + port = queue.get() + client = mp.Process( + target=_client, args=(port, endpoint_error_handling, server_close_callback), + ) + client.start() + client.join(timeout=10) + server.join(timeout=10) + assert client.exitcode == 0 + assert server.exitcode == 0 diff --git a/ucp/_libs/ucx_endpoint.pyx b/ucp/_libs/ucx_endpoint.pyx index 64e8d56d4..83f237d11 100644 --- a/ucp/_libs/ucx_endpoint.pyx +++ b/ucp/_libs/ucx_endpoint.pyx @@ -17,6 +17,21 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError logger = 
logging.getLogger("ucx") +class UCXEndpointCloseCallback(): + def __init__(self): + self._cb_func = None + + def run(self): + if self._cb_func is not None: + # Deregister callback to prevent calling from the endpoint error + # callback and again from the finalizer. + cb_func, self._cb_func = self._cb_func, None + cb_func() + + def set(self, cb_func): + self._cb_func = cb_func + + cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil: cdef UCXEndpoint ucx_ep = arg assert ucx_ep.worker.initialized @@ -30,6 +45,7 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil: hex(int(ep)), status, status_str ) ) + ucx_ep._endpoint_close_callback.run() logger.debug(msg) @@ -59,6 +75,7 @@ def _ucx_endpoint_finalizer( bint endpoint_error_handling, UCXWorker worker, set inflight_msgs, + object endpoint_close_callback, ): assert worker.initialized cdef ucp_ep_h handle = handle_as_int @@ -120,6 +137,8 @@ def _ucx_endpoint_finalizer( msg = ucs_status_string(UCS_PTR_STATUS(status)).decode("utf-8") raise UCXError("Error while closing endpoint: %s" % msg) + endpoint_close_callback.run() + cdef class UCXEndpoint(UCXObject): """Python representation of `ucp_ep_h`""" @@ -128,6 +147,7 @@ cdef class UCXEndpoint(UCXObject): uintptr_t _status bint _endpoint_error_handling set _inflight_msgs + object _endpoint_close_callback cdef readonly: UCXWorker worker @@ -143,6 +163,7 @@ cdef class UCXEndpoint(UCXObject): assert worker.initialized self.worker = worker self._inflight_msgs = set() + self._endpoint_close_callback = UCXEndpointCloseCallback() cdef ucp_err_handler_cb_t err_cb cdef uintptr_t ep_status @@ -172,6 +193,7 @@ cdef class UCXEndpoint(UCXObject): endpoint_error_handling, worker, self._inflight_msgs, + self._endpoint_close_callback, ) worker.add_child(self) @@ -305,3 +327,6 @@ cdef class UCXEndpoint(UCXObject): def unpack_rkey(self, rkey): return UCXRkey(self, rkey) + + def set_close_callback(self, cb_func): + 
self._endpoint_close_callback.set(cb_func) diff --git a/ucp/core.py b/ucp/core.py index f820ed377..521d2b103 100644 --- a/ucp/core.py +++ b/ucp/core.py @@ -862,6 +862,28 @@ async def flush(self): logger.debug("[Flush] ep: %s" % (hex(self.uid))) return await comm.flush_ep(self._ep) + def set_close_callback(self, callback_func): + """Register a user callback function to be called on Endpoint's closing. + + Allows the user to register a callback function to be called when the + Endpoint's error callback is called, or during its finalizer if the error + callback is never called. + + Once the callback is called, it's not possible to send any more messages. + However, receiving messages may still be possible, as UCP may still have + incoming messages in transit. + + Parameters + ---------- + callback_func: callable + The callback function to be called when the Endpoint's error callback + is called, otherwise called on its finalizer. + + Example + >>> ep.set_close_callback(lambda: print("Executing close callback")) + """ + self._ep.set_close_callback(callback_func) + # The following functions initialize and use a single ApplicationContext instance