-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add functionality for user to register close callback (#795)
* Add functionality for user to register close callback When a user callback is registered, the Endpoint error callback or its finalizer will run the callback to inform the user's application that the Endpoint is terminating and will not accept sending new messages, although receiving may still be possible as UCP may still have some incoming messages in transit. * Add core API close callback test * Add async API close callback test * Fix core test_close_callback Since ep.close() calls worker.progress() it can't be called within the listener callback, therefore we have to progress outside until the callback is called. We must check the processes' exitcodes to be `0`, if they timeout the exitcode is `None`, which can't be checked correctly just with `not process.exitcode`. * Make UCXEndpointCloseCallback a regular Python class * Deregister UCXEndpointCloseCallback callback before calling it * Avoid partial function in test_close_callback
- Loading branch information
Showing
4 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
|
||
import ucp | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("server_close_callback", [True, False]) | ||
async def test_close_callback(server_close_callback): | ||
endpoint_error_handling = ucp.get_ucx_version() >= (1, 10, 0) | ||
closed = [False] | ||
|
||
def _close_callback(): | ||
closed[0] = True | ||
|
||
async def server_node(ep): | ||
if server_close_callback is True: | ||
ep.set_close_callback(_close_callback) | ||
msg = bytearray(10) | ||
await ep.recv(msg) | ||
if server_close_callback is False: | ||
await ep.close() | ||
|
||
async def client_node(port): | ||
ep = await ucp.create_endpoint( | ||
ucp.get_address(), port, endpoint_error_handling=endpoint_error_handling | ||
) | ||
if server_close_callback is False: | ||
ep.set_close_callback(_close_callback) | ||
await ep.send(bytearray(b"0" * 10)) | ||
if server_close_callback is True: | ||
await ep.close() | ||
|
||
listener = ucp.create_listener( | ||
server_node, endpoint_error_handling=endpoint_error_handling | ||
) | ||
await client_node(listener.port) | ||
assert closed[0] is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import functools | ||
import multiprocessing as mp | ||
|
||
import pytest | ||
|
||
from ucp._libs import ucx_api | ||
|
||
mp = mp.get_context("spawn") | ||
|
||
|
||
def _close_callback(closed): | ||
closed[0] = True | ||
|
||
|
||
def _server(queue, endpoint_error_handling, server_close_callback): | ||
"""Server that send received message back to the client | ||
Notice, since it is illegal to call progress() in call-back functions, | ||
we use a "chain" of call-back functions. | ||
""" | ||
ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) | ||
worker = ucx_api.UCXWorker(ctx) | ||
|
||
listener_finished = [False] | ||
closed = [False] | ||
|
||
# A reference to listener's endpoint is stored to prevent it from going | ||
# out of scope too early. | ||
# ep = None | ||
|
||
def _listener_handler(conn_request): | ||
global ep | ||
ep = ucx_api.UCXEndpoint.create_from_conn_request( | ||
worker, conn_request, endpoint_error_handling=endpoint_error_handling, | ||
) | ||
if server_close_callback is True: | ||
ep.set_close_callback(functools.partial(_close_callback, closed)) | ||
listener_finished[0] = True | ||
|
||
listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) | ||
queue.put(listener.port) | ||
|
||
if server_close_callback is True: | ||
while closed[0] is False: | ||
worker.progress() | ||
assert closed[0] is True | ||
else: | ||
while listener_finished[0] is False: | ||
worker.progress() | ||
|
||
|
||
def _client(port, endpoint_error_handling, server_close_callback): | ||
ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) | ||
worker = ucx_api.UCXWorker(ctx) | ||
ep = ucx_api.UCXEndpoint.create( | ||
worker, "localhost", port, endpoint_error_handling=endpoint_error_handling, | ||
) | ||
if server_close_callback is True: | ||
ep.close() | ||
worker.progress() | ||
else: | ||
closed = [False] | ||
ep.set_close_callback(functools.partial(_close_callback, closed)) | ||
while closed[0] is False: | ||
worker.progress() | ||
|
||
|
||
@pytest.mark.parametrize("server_close_callback", [True, False]) | ||
def test_close_callback(server_close_callback): | ||
endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0) | ||
|
||
queue = mp.Queue() | ||
server = mp.Process( | ||
target=_server, args=(queue, endpoint_error_handling, server_close_callback), | ||
) | ||
server.start() | ||
port = queue.get() | ||
client = mp.Process( | ||
target=_client, args=(port, endpoint_error_handling, server_close_callback), | ||
) | ||
client.start() | ||
client.join(timeout=10) | ||
server.join(timeout=10) | ||
assert client.exitcode == 0 | ||
assert server.exitcode == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters