diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index db57bc579e..fae1e5b6d2 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -563,7 +563,7 @@ async def client_communicate(key, delay=0): @pytest.mark.gpu @gen_test() -async def test_ucx_client_server(): +async def test_ucx_client_server(ucx_loop): pytest.importorskip("distributed.comm.ucx") ucp = pytest.importorskip("ucp") diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index f4a5729826..79fc298284 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -22,25 +22,7 @@ HOST = "127.0.0.1" -def handle_exception(loop, context): - msg = context.get("exception", context["message"]) - print(msg) - - -# Let's make sure that UCX gets time to cancel -# progress tasks before closing the event loop. -@pytest.fixture() -def event_loop(scope="function"): - loop = asyncio.new_event_loop() - loop.set_exception_handler(handle_exception) - ucp.reset() - yield loop - ucp.reset() - loop.run_until_complete(asyncio.sleep(0)) - loop.close() - - -def test_registered(): +def test_registered(ucx_loop): assert "ucx" in backends backend = get_backend("ucx") assert isinstance(backend, ucx.UCXBackend) @@ -62,7 +44,7 @@ async def handle_comm(comm): @gen_test() -async def test_ping_pong(): +async def test_ping_pong(ucx_loop): com, serv_com = await get_comm_pair() msg = {"op": "ping"} await com.write(msg) @@ -80,7 +62,7 @@ async def test_ping_pong(): @gen_test() -async def test_comm_objs(): +async def test_comm_objs(ucx_loop): comm, serv_comm = await get_comm_pair() scheme, loc = parse_address(comm.peer_address) @@ -93,7 +75,7 @@ async def test_comm_objs(): @gen_test() -async def test_ucx_specific(): +async def test_ucx_specific(ucx_loop): """ Test concrete UCX API. """ @@ -147,7 +129,7 @@ async def client_communicate(key, delay=0): @gen_test() -async def test_ping_pong_data(): +async def test_ping_pong_data(ucx_loop): np = pytest.importorskip("numpy") data = np.ones((10, 10)) @@ -170,7 +152,7 @@ async def test_ping_pong_data(): @gen_test() -async def test_ucx_deserialize(): +async def test_ucx_deserialize(ucx_loop): # Note we see this error on some systems with this test: # `socket.gaierror: [Errno -5] No address associated with hostname` # This may be due to a system configuration issue. @@ -196,7 +178,7 @@ async def test_ucx_deserialize(): ], ) @gen_test() -async def test_ping_pong_cudf(g): +async def test_ping_pong_cudf(ucx_loop, g): # if this test appears after cupy an import error arises # *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11' # not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12) @@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g): @pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)]) @gen_test() -async def test_ping_pong_cupy(shape): +async def test_ping_pong_cupy(ucx_loop, shape): cupy = pytest.importorskip("cupy") com, serv_com = await get_comm_pair() @@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape): @pytest.mark.slow @pytest.mark.parametrize("n", [int(1e9), int(2.5e9)]) @gen_test() -async def test_large_cupy(n, cleanup): +async def test_large_cupy(ucx_loop, n, cleanup): cupy = pytest.importorskip("cupy") com, serv_com = await get_comm_pair() @@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup): @gen_test() -async def test_ping_pong_numba(): +async def test_ping_pong_numba(ucx_loop): np = pytest.importorskip("numpy") numba = pytest.importorskip("numba") import numba.cuda @@ -276,7 +258,7 @@ async def test_ping_pong_numba(): @pytest.mark.parametrize("processes", [True, False]) @gen_test() -async def test_ucx_localcluster(processes, cleanup): +async def test_ucx_localcluster(ucx_loop, processes, cleanup): async with LocalCluster( protocol="ucx", host=HOST, @@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup): @pytest.mark.slow @gen_test(timeout=60) -async def test_stress(): +async def test_stress( + ucx_loop, +): da = pytest.importorskip("dask.array") chunksize = "10 MB" @@ -322,7 +306,9 @@ async def test_stress(): @gen_test() -async def test_simple(): +async def test_simple( + ucx_loop, +): async with LocalCluster(protocol="ucx", asynchronous=True) as cluster: async with Client(cluster, asynchronous=True) as client: assert cluster.scheduler_address.startswith("ucx://") @@ -330,7 +316,9 @@ async def test_simple(): @gen_test() -async def test_cuda_context(): +async def test_cuda_context( + ucx_loop, +): with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}): async with LocalCluster( protocol="ucx", n_workers=1, asynchronous=True @@ -344,7 +332,9 @@ async def test_cuda_context(): @gen_test() -async def test_transpose(): +async def test_transpose( + ucx_loop, +): da = pytest.importorskip("dask.array") async with LocalCluster(protocol="ucx", asynchronous=True) as cluster: @@ -358,7 +348,7 @@ async def test_transpose(): @pytest.mark.parametrize("port", [0, 1234]) @gen_test() -async def test_ucx_protocol(cleanup, port): +async def test_ucx_protocol(ucx_loop, cleanup, port): async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s: assert s.address.startswith("ucx://") @@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port): not hasattr(ucp.exceptions, "UCXUnreachable"), reason="Requires UCX-Py support for UCXUnreachable exception", ) -def test_ucx_unreachable(): - if ucp.get_ucx_version() > (1, 12, 0): - with pytest.raises(OSError, match="Timed out trying to connect to"): - Client("ucx://255.255.255.255:12345", timeout=1) - else: - with pytest.raises(ucp.exceptions.UCXError, match="Destination is unreachable"): - Client("ucx://255.255.255.255:12345", timeout=1) +@gen_test() +async def test_ucx_unreachable( + ucx_loop, +): + with pytest.raises(OSError, match="Timed out trying to connect to"): + await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True) diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index f1f3f08a3a..baff89e611 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -22,7 +22,7 @@ @gen_test() -async def test_ucx_config(cleanup): +async def test_ucx_config(ucx_loop, cleanup): ucx = { "nvlink": True, "infiniband": True, @@ -79,7 +79,7 @@ async def test_ucx_config(cleanup): reruns=10, reruns_delay=5, ) -def test_ucx_config_w_env_var(cleanup, loop): +def test_ucx_config_w_env_var(ucx_loop, cleanup, loop): env = os.environ.copy() env["DASK_RMM__POOL_SIZE"] = "1000.00 MB" diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index 739cfe39e7..b4742a7c65 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -397,11 +397,7 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC init_once() try: ep = await ucp.create_endpoint(ip, port) - except (ucp.exceptions.UCXCloseError, ucp.exceptions.UCXCanceled,) + ( - getattr(ucp.exceptions, "UCXConnectionReset", ()), - getattr(ucp.exceptions, "UCXNotConnected", ()), - getattr(ucp.exceptions, "UCXUnreachable", ()), - ): # type: ignore + except ucp.exceptions.UCXBaseException: raise CommClosedError("Connection closed before handshake completed") return self.comm_class( ep, diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 06b4f95c8c..db4331c5ca 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -460,7 +460,7 @@ def raise_err(): @pytest.mark.parametrize("protocol", ["tcp", "ucx"]) @gen_test() -async def test_nanny_closed_by_keyboard_interrupt(protocol): +async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol): if protocol == "ucx": # Skip if UCX isn't available pytest.importorskip("ucp") diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9e9330e45e..433a6b41d0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1384,7 +1384,7 @@ async def test_interface_async(Worker): @pytest.mark.gpu @pytest.mark.parametrize("Worker", [Worker, Nanny]) @gen_test() -async def test_protocol_from_scheduler_address(Worker): +async def test_protocol_from_scheduler_address(ucx_loop, Worker): pytest.importorskip("ucp") async with Scheduler(protocol="ucx", dashboard_address=":0") as s: diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 87d5c83958..e9d0d91e68 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -2153,3 +2153,39 @@ def raises_with_cause( assert re.search( match_cause, str(exc.__cause__) ), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``" + + +def ucx_exception_handler(loop, context): + """UCX exception handler for `ucx_loop` during test. + + Prints the exception and its message. + + Parameters + ---------- + loop: object + Reference to the running event loop + context: dict + Dictionary containing exception details. + """ + msg = context.get("exception", context["message"]) + print(msg) + + +# Let's make sure that UCX gets time to cancel +# progress tasks before closing the event loop. +@pytest.fixture(scope="function") +def ucx_loop(): + """Allows UCX to cancel progress tasks before closing event loop. + + When UCX tasks are not completed in time (e.g., by unexpected Endpoint + closure), clean up tasks before closing the event loop to prevent unwanted + errors from being raised. + """ + ucp = pytest.importorskip("ucp") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(ucx_exception_handler) + ucp.reset() + yield loop + ucp.reset() + loop.close()