Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality for user to register close callback #795

Merged
merged 9 commits into from
Oct 28, 2021
37 changes: 37 additions & 0 deletions tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

import ucp


@pytest.mark.asyncio
@pytest.mark.parametrize("server_close_callback", [True, False])
async def test_close_callback(server_close_callback):
    """Check that a registered close callback has run once the client is done.

    The callback is registered on exactly one side of the connection, chosen
    by ``server_close_callback``: on the server's endpoint when true, on the
    client's endpoint when false. The opposite side closes its endpoint after
    the 10-byte message has been sent (client) or received (server).
    """
    # Endpoint error handling is only requested for UCX >= 1.10.0.
    endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0)
    # One-element list so the nested callback can flip the flag in place.
    closed = [False]

    def _close_callback():
        closed[0] = True

    async def server_node(ep):
        if server_close_callback:
            ep.set_close_callback(_close_callback)
        msg = bytearray(10)
        await ep.recv(msg)
        if not server_close_callback:
            await ep.close()

    async def client_node(port):
        ep = await ucp.create_endpoint(
            ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling
        )
        if not server_close_callback:
            ep.set_close_callback(_close_callback)
        await ep.send(bytearray(b"0" * 10))
        if server_close_callback:
            await ep.close()

    listener = ucp.create_listener(
        server_node, endpoint_error_handling=endpoint_error_handling
    )
    await client_node(listener.port)
    assert closed[0]
85 changes: 85 additions & 0 deletions ucp/_libs/tests/test_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import functools
import multiprocessing as mp

import pytest

from ucp._libs import ucx_api

mp = mp.get_context("spawn")


def _close_callback(closed):
closed[0] = True


def _server(queue, endpoint_error_handling, server_close_callback):
    """Server process: accept one connection and drive the worker.

    Creates a listener (requested with ``port=0``) and publishes
    ``listener.port`` through ``queue`` so the client knows where to connect.
    When ``server_close_callback`` is true, a close callback is registered on
    the accepted endpoint and the worker is progressed until that callback
    has fired. Otherwise the worker is progressed only until the connection
    request has been handled.

    Notice, since it is illegal to call progress() in call-back functions,
    the callbacks only set flags, which the loops at the end poll.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)

    # One-element lists so the callbacks can flip the flags in place.
    listener_finished = [False]
    closed = [False]

    def _listener_handler(conn_request):
        # A reference to the listener's endpoint is stored in a module-level
        # name to prevent it from going out of scope too early.
        global ep
        ep = ucx_api.UCXEndpoint.create_from_conn_request(
            worker, conn_request, endpoint_error_handling=endpoint_error_handling,
        )
        if server_close_callback:
            ep.set_close_callback(functools.partial(_close_callback, closed))
        listener_finished[0] = True

    listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler)
    queue.put(listener.port)

    if server_close_callback:
        # Leaving this loop is itself the proof that the callback fired.
        while not closed[0]:
            worker.progress()
    else:
        while not listener_finished[0]:
            worker.progress()


def _client(port, endpoint_error_handling, server_close_callback):
    """Client process: connect to ``localhost:port`` and exercise one close path.

    When ``server_close_callback`` is true, the server owns the callback, so
    the client closes its endpoint and progresses the worker once. Otherwise
    the client registers its own close callback and progresses the worker
    until that callback has set its flag.
    """
    ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,))
    worker = ucx_api.UCXWorker(ctx)
    ep = ucx_api.UCXEndpoint.create(
        worker, "localhost", port, endpoint_error_handling=endpoint_error_handling,
    )
    if server_close_callback:
        ep.close()
        worker.progress()
    else:
        # One-element list so the callback can flip the flag in place.
        closed = [False]
        ep.set_close_callback(functools.partial(_close_callback, closed))
        while not closed[0]:
            worker.progress()


@pytest.mark.parametrize("server_close_callback", [True, False])
def test_close_callback(server_close_callback):
    """Run ``_server`` and ``_client`` in separate processes; both must exit 0.

    ``server_close_callback`` selects which side registers the close
    callback. Each process gets 10 seconds to finish. A process that is
    still alive afterwards has ``exitcode`` ``None``, which fails the
    assertions, and it is terminated so it does not outlive the test.
    """
    # Endpoint error handling is only requested for UCX >= 1.10.0.
    endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0)

    queue = mp.Queue()
    server = mp.Process(
        target=_server, args=(queue, endpoint_error_handling, server_close_callback),
    )
    server.start()
    client = None
    try:
        # The server publishes the listener port once it is ready.
        port = queue.get()
        client = mp.Process(
            target=_client, args=(port, endpoint_error_handling, server_close_callback),
        )
        client.start()
        client.join(timeout=10)
        server.join(timeout=10)
        assert client.exitcode == 0
        assert server.exitcode == 0
    finally:
        # Do not leak hung child processes when a join timed out or an
        # assertion failed.
        for proc in (client, server):
            if proc is not None and proc.is_alive():
                proc.terminate()
                proc.join(timeout=10)
25 changes: 25 additions & 0 deletions ucp/_libs/ucx_endpoint.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ from ..exceptions import UCXCanceled, UCXConnectionReset, UCXError
logger = logging.getLogger("ucx")


class UCXEndpointCloseCallback:
    """Holder for a user close callback that runs at most once per registration.

    ``run()`` deregisters the callback before invoking it, so calling
    ``run()`` again is a no-op until ``set()`` registers another callback.
    """

    def __init__(self):
        # No callback is registered until set() is called.
        self._cb_func = None

    def run(self):
        """Invoke the registered callback, if any, after deregistering it."""
        if self._cb_func is not None:
            # Deregister callback to prevent calling from the endpoint error
            # callback and again from the finalizer.
            cb_func, self._cb_func = self._cb_func, None
            cb_func()

    def set(self, cb_func):
        """Register ``cb_func``, replacing any previously registered callback.

        ``cb_func`` is called with no arguments by ``run()``; passing ``None``
        clears the registration.
        """
        self._cb_func = cb_func


cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
cdef UCXEndpoint ucx_ep = <UCXEndpoint> arg
assert ucx_ep.worker.initialized
Expand All @@ -30,6 +45,7 @@ cdef void _err_cb(void *arg, ucp_ep_h ep, ucs_status_t status) with gil:
hex(int(<uintptr_t>ep)), status, status_str
)
)
ucx_ep._endpoint_close_callback.run()
logger.debug(msg)


Expand Down Expand Up @@ -59,6 +75,7 @@ def _ucx_endpoint_finalizer(
bint endpoint_error_handling,
UCXWorker worker,
set inflight_msgs,
object endpoint_close_callback,
):
assert worker.initialized
cdef ucp_ep_h handle = <ucp_ep_h>handle_as_int
Expand Down Expand Up @@ -120,6 +137,8 @@ def _ucx_endpoint_finalizer(
msg = ucs_status_string(UCS_PTR_STATUS(status)).decode("utf-8")
raise UCXError("Error while closing endpoint: %s" % msg)

endpoint_close_callback.run()


cdef class UCXEndpoint(UCXObject):
"""Python representation of `ucp_ep_h`"""
Expand All @@ -128,6 +147,7 @@ cdef class UCXEndpoint(UCXObject):
uintptr_t _status
bint _endpoint_error_handling
set _inflight_msgs
object _endpoint_close_callback

cdef readonly:
UCXWorker worker
Expand All @@ -143,6 +163,7 @@ cdef class UCXEndpoint(UCXObject):
assert worker.initialized
self.worker = worker
self._inflight_msgs = set()
self._endpoint_close_callback = UCXEndpointCloseCallback()

cdef ucp_err_handler_cb_t err_cb
cdef uintptr_t ep_status
Expand Down Expand Up @@ -172,6 +193,7 @@ cdef class UCXEndpoint(UCXObject):
endpoint_error_handling,
worker,
self._inflight_msgs,
self._endpoint_close_callback,
)
worker.add_child(self)

Expand Down Expand Up @@ -305,3 +327,6 @@ cdef class UCXEndpoint(UCXObject):

def unpack_rkey(self, rkey):
return UCXRkey(self, rkey)

def set_close_callback(self, cb_func):
    """Register ``cb_func`` as this endpoint's close callback.

    The callable is stored in the endpoint's close-callback holder,
    replacing any callback registered earlier.
    """
    self._endpoint_close_callback.set(cb_func)
22 changes: 22 additions & 0 deletions ucp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,28 @@ async def flush(self):
logger.debug("[Flush] ep: %s" % (hex(self.uid)))
return await comm.flush_ep(self._ep)

def set_close_callback(self, callback_func):
    """Register a user callback function to be called on Endpoint's closing.

    Allows the user to register a callback function to be called when the
    Endpoint's error callback is called, or during its finalizer if the error
    callback is never called.

    Once the callback is called, it's not possible to send any more messages.
    However, receiving messages may still be possible, as UCP may still have
    incoming messages in transit.

    Parameters
    ----------
    callback_func: callable
        The callback function to be called when the Endpoint's error callback
        is called, otherwise called on its finalizer.

    Example
    -------
    >>> ep.set_close_callback(lambda: print("Executing close callback"))
    """
    self._ep.set_close_callback(callback_func)


# The following functions initialize and use a single ApplicationContext instance

Expand Down