From a0178da84534e4b84828cce7bdc55f9aae72c394 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 18 Oct 2021 09:52:58 -0700 Subject: [PATCH 1/7] Add functionality for user to register close callback When a user callback is registered, the Endpoint error callback or its finalizer will run the callback to inform the user's application that the Endpoint is terminating and will not accept sending new messages, although receiving may still be possible as UCP may still have some incoming messages in transit. --- ucp/_libs/ucx_endpoint.pyx | 29 +++++++++++++++++++++++++++++ ucp/core.py | 22 ++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/ucp/_libs/ucx_endpoint.pyx b/ucp/_libs/ucx_endpoint.pyx index 373002226..9511c4ce9 100644 --- a/ucp/_libs/ucx_endpoint.pyx +++ b/ucp/_libs/ucx_endpoint.pyx @@ -17,6 +17,25 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError logger = logging.getLogger("ucx") +cdef class UCXEndpointCloseCallback(): + cdef: + object _cb_func + + def __init__(self): + self._cb_func = None + + def run(self): + if self._cb_func is not None: + self._cb_func() + + # Deregister callback to prevent calling from the endpoint error + # callback and again from the finalizer. + self._cb_func = None + + def set(self, cb_func): + self._cb_func = cb_func + + cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status): cdef UCXEndpoint ucx_ep = arg assert ucx_ep.worker.initialized @@ -30,6 +49,7 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status): hex(int(ep)), status, status_str ) ) + ucx_ep._endpoint_close_callback.run() logger.debug(msg) @@ -59,6 +79,7 @@ def _ucx_endpoint_finalizer( bint endpoint_error_handling, UCXWorker worker, set inflight_msgs, + UCXEndpointCloseCallback endpoint_close_callback, ): assert worker.initialized cdef ucp_ep_h handle = handle_as_int @@ -120,6 +141,8 @@ def _ucx_endpoint_finalizer( msg = ucs_status_string(UCS_PTR_STATUS(status)).decode("utf-8") raise UCXError("Error while closing endpoint: %s" % msg) + endpoint_close_callback.run() + cdef class UCXEndpoint(UCXObject): """Python representation of `ucp_ep_h`""" @@ -128,6 +151,7 @@ cdef class UCXEndpoint(UCXObject): uintptr_t _status bint _endpoint_error_handling set _inflight_msgs + UCXEndpointCloseCallback _endpoint_close_callback cdef readonly: UCXWorker worker @@ -143,6 +167,7 @@ cdef class UCXEndpoint(UCXObject): assert worker.initialized self.worker = worker self._inflight_msgs = set() + self._endpoint_close_callback = UCXEndpointCloseCallback() cdef ucp_err_handler_cb_t err_cb cdef uintptr_t ep_status @@ -172,6 +197,7 @@ cdef class UCXEndpoint(UCXObject): endpoint_error_handling, worker, self._inflight_msgs, + self._endpoint_close_callback, ) worker.add_child(self) @@ -305,3 +331,6 @@ cdef class UCXEndpoint(UCXObject): def unpack_rkey(self, rkey): return UCXRkey(self, rkey) + + def set_close_callback(self, cb_func): + self._endpoint_close_callback.set(cb_func) diff --git a/ucp/core.py b/ucp/core.py index f820ed377..521d2b103 100644 --- a/ucp/core.py +++ b/ucp/core.py @@ -862,6 +862,28 @@ async def flush(self): logger.debug("[Flush] ep: %s" % (hex(self.uid))) return await comm.flush_ep(self._ep) + def set_close_callback(self, callback_func): + """Register a user callback function to be called on Endpoint's closing. + + Allows the user to register a callback function to be called when the + Endpoint's error callback is called, or during its finalizer if the error + callback is never called. + + Once the callback is called, it's not possible to send any more messages. + However, receiving messages may still be possible, as UCP may still have + incoming messages in transit. + + Parameters + ---------- + callback_func: callable + The callback function to be called when the Endpoint's error callback + is called, otherwise called on its finalizer. + + Example + >>> ep.set_close_callback(lambda: print("Executing close callback")) + """ + self._ep.set_close_callback(callback_func) + # The following functions initialize and use a single ApplicationContext instance From c99fdceef09d89cdfdb475da60941da8023bef1c Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 18 Oct 2021 14:19:59 -0700 Subject: [PATCH 2/7] Add core API close callback test --- ucp/_libs/tests/test_endpoint.py | 85 ++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 ucp/_libs/tests/test_endpoint.py diff --git a/ucp/_libs/tests/test_endpoint.py b/ucp/_libs/tests/test_endpoint.py new file mode 100644 index 000000000..9cacf4031 --- /dev/null +++ b/ucp/_libs/tests/test_endpoint.py @@ -0,0 +1,85 @@ +import functools +import multiprocessing as mp + +import pytest + +from ucp._libs import ucx_api + +mp = mp.get_context("spawn") + + +def _close_callback(closed): + closed[0] = True + + +def _echo_server(queue, endpoint_error_handling, server_close_callback): + """Server that send received message back to the client + + Notice, since it is illegal to call progress() in call-back functions, + we use a "chain" of call-back functions. + """ + ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) + worker = ucx_api.UCXWorker(ctx) + + listener_finished = [False] + closed = [False] + + # A reference to listener's endpoint is stored to prevent it from going + # out of scope too early. + # ep = None + + def _listener_handler(conn_request): + global ep + ep = ucx_api.UCXEndpoint.create_from_conn_request( + worker, conn_request, endpoint_error_handling=endpoint_error_handling, + ) + if server_close_callback is True: + ep.set_close_callback(functools.partial(_close_callback, closed)) + ep.close() + listener_finished[0] = True + + listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) + queue.put(listener.port) + + while listener_finished[0] is False: + worker.progress() + if server_close_callback is True: + assert closed[0] is True + + +def _echo_client(port, endpoint_error_handling, server_close_callback): + ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) + worker = ucx_api.UCXWorker(ctx) + ep = ucx_api.UCXEndpoint.create( + worker, "localhost", port, endpoint_error_handling=endpoint_error_handling, + ) + if server_close_callback is True: + ep.close() + worker.progress() + else: + closed = [False] + ep.set_close_callback(functools.partial(_close_callback, closed)) + while closed[0] is False: + worker.progress() + + +@pytest.mark.parametrize("server_close_callback", [True, False]) +def test_close_callback(server_close_callback): + endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0) + + queue = mp.Queue() + server = mp.Process( + target=_echo_server, + args=(queue, endpoint_error_handling, server_close_callback), + ) + server.start() + port = queue.get() + client = mp.Process( + target=_echo_client, + args=(port, endpoint_error_handling, server_close_callback), + ) + client.start() + client.join(timeout=10) + assert not client.exitcode + server.join(timeout=10) + assert not server.exitcode From c287e238368a0847f45558eabb8107f97d65f786 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 18 Oct 2021 15:05:50 -0700 Subject: [PATCH 3/7] Add async API close callback test --- tests/test_endpoint.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 tests/test_endpoint.py diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py new file mode 100644 index 000000000..9db8826f5 --- /dev/null +++ b/tests/test_endpoint.py @@ -0,0 +1,40 @@ +import functools + +import pytest + +import ucp + + +def _close_callback(closed): + closed[0] = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server_close_callback", [True, False]) +async def test_close_callback(server_close_callback): + endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0) + closed = [False] + + async def server_node(ep): + if server_close_callback is True: + ep.set_close_callback(functools.partial(_close_callback, closed)) + msg = bytearray(10) + await ep.recv(msg) + if server_close_callback is False: + await ep.close() + + async def client_node(port): + ep = await ucp.create_endpoint( + ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling + ) + if server_close_callback is False: + ep.set_close_callback(functools.partial(_close_callback, closed)) + await ep.send(bytearray(b"0" * 10)) + if server_close_callback is True: + await ep.close() + + listener = ucp.create_listener( + server_node, endpoint_error_handling=endpoint_error_handling + ) + await client_node(listener.port) + assert closed[0] is True From aad69b90c7d34060c2055b2d7b959959afcc39c8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 27 Oct 2021 09:34:07 -0700 Subject: [PATCH 4/7] Fix core test_close_callback Since ep.close() calls worker.progress() it can't be called within the listener callback, therefore we have to progress outside until the callback is called. We must check the processes' exitcodes to be `0`, if they timeout the exitcode is `None`, which can't be checked correctly just with `not process.exitcode`. --- ucp/_libs/tests/test_endpoint.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/ucp/_libs/tests/test_endpoint.py b/ucp/_libs/tests/test_endpoint.py index 9cacf4031..53d2b6d21 100644 --- a/ucp/_libs/tests/test_endpoint.py +++ b/ucp/_libs/tests/test_endpoint.py @@ -12,7 +12,7 @@ def _close_callback(closed): closed[0] = True -def _echo_server(queue, endpoint_error_handling, server_close_callback): +def _server(queue, endpoint_error_handling, server_close_callback): """Server that send received message back to the client Notice, since it is illegal to call progress() in call-back functions, @@ -35,19 +35,21 @@ def _listener_handler(conn_request): ) if server_close_callback is True: ep.set_close_callback(functools.partial(_close_callback, closed)) - ep.close() listener_finished[0] = True listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) queue.put(listener.port) - while listener_finished[0] is False: - worker.progress() if server_close_callback is True: + while closed[0] is False: + worker.progress() assert closed[0] is True + else: + while listener_finished[0] is False: + worker.progress() -def _echo_client(port, endpoint_error_handling, server_close_callback): +def _client(port, endpoint_error_handling, server_close_callback): ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) worker = ucx_api.UCXWorker(ctx) ep = ucx_api.UCXEndpoint.create( @@ -69,17 +71,15 @@ def test_close_callback(server_close_callback): queue = mp.Queue() server = mp.Process( - target=_echo_server, - args=(queue, endpoint_error_handling, server_close_callback), + target=_server, args=(queue, endpoint_error_handling, server_close_callback), ) server.start() port = queue.get() client = mp.Process( - target=_echo_client, - args=(port, endpoint_error_handling, server_close_callback), + target=_client, args=(port, endpoint_error_handling, server_close_callback), ) client.start() client.join(timeout=10) - assert not client.exitcode server.join(timeout=10) - assert not server.exitcode + assert client.exitcode == 0 + assert server.exitcode == 0 From 29433482677e4a0c835a64c372f41b94990284e8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 28 Oct 2021 00:48:16 -0700 Subject: [PATCH 5/7] Make UCXEndpointCloseCallback a regular Python class --- ucp/_libs/ucx_endpoint.pyx | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ucp/_libs/ucx_endpoint.pyx b/ucp/_libs/ucx_endpoint.pyx index ad3d8c6cd..f67ba5af4 100644 --- a/ucp/_libs/ucx_endpoint.pyx +++ b/ucp/_libs/ucx_endpoint.pyx @@ -17,10 +17,7 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError logger = logging.getLogger("ucx") -cdef class UCXEndpointCloseCallback(): - cdef: - object _cb_func - +class UCXEndpointCloseCallback(): def __init__(self): self._cb_func = None @@ -79,7 +76,7 @@ def _ucx_endpoint_finalizer( bint endpoint_error_handling, UCXWorker worker, set inflight_msgs, - UCXEndpointCloseCallback endpoint_close_callback, + object endpoint_close_callback, ): assert worker.initialized cdef ucp_ep_h handle = handle_as_int @@ -151,7 +148,7 @@ cdef class UCXEndpoint(UCXObject): uintptr_t _status bint _endpoint_error_handling set _inflight_msgs - UCXEndpointCloseCallback _endpoint_close_callback + object _endpoint_close_callback cdef readonly: UCXWorker worker From 09fa9d96427ec5582314367f87c74a54d6173925 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 28 Oct 2021 00:53:13 -0700 Subject: [PATCH 6/7] Deregister UCXEndpointCloseCallback callback before calling it --- ucp/_libs/ucx_endpoint.pyx | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ucp/_libs/ucx_endpoint.pyx b/ucp/_libs/ucx_endpoint.pyx index f67ba5af4..83f237d11 100644 --- a/ucp/_libs/ucx_endpoint.pyx +++ b/ucp/_libs/ucx_endpoint.pyx @@ -23,11 +23,10 @@ class UCXEndpointCloseCallback(): def run(self): if self._cb_func is not None: - self._cb_func() - # Deregister callback to prevent calling from the endpoint error # callback and again from the finalizer. - self._cb_func = None + cb_func, self._cb_func = self._cb_func, None + cb_func() def set(self, cb_func): self._cb_func = cb_func From 4ca22e4a60389204fd3f3a3868a0b052d26de07e Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 28 Oct 2021 00:56:09 -0700 Subject: [PATCH 7/7] Avoid partial function in test_close_callback --- tests/test_endpoint.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py index 9db8826f5..332d15da1 100644 --- a/tests/test_endpoint.py +++ b/tests/test_endpoint.py @@ -1,23 +1,20 @@ -import functools - import pytest import ucp -def _close_callback(closed): - closed[0] = True - - @pytest.mark.asyncio @pytest.mark.parametrize("server_close_callback", [True, False]) async def test_close_callback(server_close_callback): endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0) closed = [False] + def _close_callback(): + closed[0] = True + async def server_node(ep): if server_close_callback is True: - ep.set_close_callback(functools.partial(_close_callback, closed)) + ep.set_close_callback(_close_callback) msg = bytearray(10) await ep.recv(msg) if server_close_callback is False: @@ -28,7 +25,7 @@ async def client_node(port): ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling ) if server_close_callback is False: - ep.set_close_callback(functools.partial(_close_callback, closed)) + ep.set_close_callback(_close_callback) await ep.send(bytearray(b"0" * 10)) if server_close_callback is True: await ep.close()