Skip to content

Commit

Permalink
Merge branch 'main' into WSMR/retry_busy_worker
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed May 30, 2022
2 parents 1741ed2 + 5feb171 commit f84e8f7
Show file tree
Hide file tree
Showing 25 changed files with 594 additions and 392 deletions.
3 changes: 0 additions & 3 deletions distributed/cli/dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import warnings

import click
from tornado.ioloop import IOLoop

from distributed import Scheduler
from distributed._signals import wait_for_signals
Expand Down Expand Up @@ -186,11 +185,9 @@ def del_pid_file():
resource.setrlimit(resource.RLIMIT_NOFILE, (limit, hard))

async def run():
loop = IOLoop.current()
logger.info("-" * 47)

scheduler = Scheduler(
loop=loop,
security=sec,
host=host,
port=port,
Expand Down
19 changes: 6 additions & 13 deletions distributed/cli/tests/test_dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
assert_can_connect_from_everywhere_4_6,
assert_can_connect_locally_4,
popen,
wait_for_log_line,
)


Expand Down Expand Up @@ -66,12 +67,8 @@ def test_dashboard(loop):
pytest.importorskip("bokeh")

with popen(["dask-scheduler"], flush_output=False) as proc:
for line in proc.stdout:
if b"dashboard at" in line:
dashboard_port = int(line.decode().split(":")[-1].strip())
break
else:
assert False # pragma: nocover
line = wait_for_log_line(b"dashboard at", proc.stdout)
dashboard_port = int(line.decode().split(":")[-1].strip())

with Client(f"127.0.0.1:{Scheduler.default_port}", loop=loop):
pass
Expand Down Expand Up @@ -223,13 +220,9 @@ def test_dashboard_port_zero(loop):
["dask-scheduler", "--dashboard-address", ":0"],
flush_output=False,
) as proc:
for line in proc.stdout:
if b"dashboard at" in line:
dashboard_port = int(line.decode().split(":")[-1].strip())
assert dashboard_port != 0
break
else:
assert False # pragma: nocover
line = wait_for_log_line(b"dashboard at", proc.stdout)
dashboard_port = int(line.decode().split(":")[-1].strip())
assert dashboard_port != 0


PRELOAD_TEXT = """
Expand Down
10 changes: 3 additions & 7 deletions distributed/cli/tests/test_dask_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from distributed import Client
from distributed.cli.dask_ssh import main
from distributed.compatibility import MACOS, WINDOWS
from distributed.utils_test import popen
from distributed.utils_test import popen, wait_for_log_line

pytest.importorskip("paramiko")
pytestmark = [
Expand All @@ -30,16 +30,12 @@ def test_ssh_cli_nprocs_renamed_to_nworkers(loop):
# This interrupt is necessary for the cluster to place output into the stdout
# and stderr pipes
proc.send_signal(2)
assert any(
b"renamed to --nworkers" in proc.stdout.readline() for _ in range(15)
)
wait_for_log_line(b"renamed to --nworkers", proc.stdout, max_lines=15)


def test_ssh_cli_nworkers_with_nprocs_is_an_error():
with popen(
["dask-ssh", "localhost", "--nprocs=2", "--nworkers=2"],
flush_output=False,
) as proc:
assert any(
b"Both --nprocs and --nworkers" in proc.stdout.readline() for _ in range(15)
)
wait_for_log_line(b"Both --nprocs and --nworkers", proc.stdout, max_lines=15)
45 changes: 11 additions & 34 deletions distributed/cli/tests/test_dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from distributed.deploy.utils import nprocesses_nthreads
from distributed.metrics import time
from distributed.utils import open_port
from distributed.utils_test import gen_cluster, popen, requires_ipv6
from distributed.utils_test import gen_cluster, popen, requires_ipv6, wait_for_log_line


@pytest.mark.parametrize(
Expand Down Expand Up @@ -246,9 +246,7 @@ async def test_nanny_worker_port_range_too_many_workers_raises(s):
],
flush_output=False,
) as worker:
assert any(
b"Not enough ports in range" in worker.stdout.readline() for _ in range(100)
)
wait_for_log_line(b"Not enough ports in range", worker.stdout, max_lines=100)


@pytest.mark.slow
Expand Down Expand Up @@ -282,26 +280,14 @@ async def test_reconnect_deprecated(c, s):
["dask-worker", s.address, "--reconnect"],
flush_output=False,
) as worker:
for _ in range(10):
line = worker.stdout.readline()
print(line)
if b"`--reconnect` option has been removed" in line:
break
else:
raise AssertionError("Message not printed, see stdout")
wait_for_log_line(b"`--reconnect` option has been removed", worker.stdout)
assert worker.wait() == 1

with popen(
["dask-worker", s.address, "--no-reconnect"],
flush_output=False,
) as worker:
for _ in range(10):
line = worker.stdout.readline()
print(line)
if b"flag is deprecated, and will be removed" in line:
break
else:
raise AssertionError("Message not printed, see stdout")
wait_for_log_line(b"flag is deprecated, and will be removed", worker.stdout)
await c.wait_for_workers(1)
await c.shutdown()

Expand Down Expand Up @@ -377,9 +363,7 @@ async def test_nworkers_requires_nanny(s):
["dask-worker", s.address, "--nworkers=2", "--no-nanny"],
flush_output=False,
) as worker:
assert any(
b"Failed to launch worker" in worker.stdout.readline() for _ in range(15)
)
wait_for_log_line(b"Failed to launch worker", worker.stdout, max_lines=15)


@pytest.mark.slow
Expand Down Expand Up @@ -419,9 +403,7 @@ async def test_worker_cli_nprocs_renamed_to_nworkers(c, s):
flush_output=False,
) as worker:
await c.wait_for_workers(2)
assert any(
b"renamed to --nworkers" in worker.stdout.readline() for _ in range(15)
)
wait_for_log_line(b"renamed to --nworkers", worker.stdout, max_lines=15)


@gen_cluster(nthreads=[])
Expand All @@ -430,10 +412,7 @@ async def test_worker_cli_nworkers_with_nprocs_is_an_error(s):
["dask-worker", s.address, "--nprocs=2", "--nworkers=2"],
flush_output=False,
) as worker:
assert any(
b"Both --nprocs and --nworkers" in worker.stdout.readline()
for _ in range(15)
)
wait_for_log_line(b"Both --nprocs and --nworkers", worker.stdout, max_lines=15)


@pytest.mark.slow
Expand Down Expand Up @@ -733,12 +712,10 @@ def test_error_during_startup(monkeypatch, nanny):
) as scheduler:
start = time()
# Wait for the scheduler to be up
while line := scheduler.stdout.readline():
if b"Scheduler at" in line:
break
# Ensure this is not killed by pytest-timeout
if time() - start > 5:
raise TimeoutError("Scheduler failed to start in time.")
wait_for_log_line(b"Scheduler at", scheduler.stdout)
# Ensure this is not killed by pytest-timeout
if time() - start > 5:
raise TimeoutError("Scheduler failed to start in time.")

with popen(
[
Expand Down
2 changes: 1 addition & 1 deletion distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ async def client_communicate(key, delay=0):

@pytest.mark.gpu
@gen_test()
async def test_ucx_client_server():
async def test_ucx_client_server(ucx_loop):
pytest.importorskip("distributed.comm.ucx")
ucp = pytest.importorskip("ucp")

Expand Down
71 changes: 30 additions & 41 deletions distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,7 @@
HOST = "127.0.0.1"


def handle_exception(loop, context):
msg = context.get("exception", context["message"])
print(msg)


# Let's make sure that UCX gets time to cancel
# progress tasks before closing the event loop.
@pytest.fixture()
def event_loop(scope="function"):
loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
ucp.reset()
yield loop
ucp.reset()
loop.run_until_complete(asyncio.sleep(0))
loop.close()


def test_registered():
def test_registered(ucx_loop):
assert "ucx" in backends
backend = get_backend("ucx")
assert isinstance(backend, ucx.UCXBackend)
Expand All @@ -62,7 +44,7 @@ async def handle_comm(comm):


@gen_test()
async def test_ping_pong():
async def test_ping_pong(ucx_loop):
com, serv_com = await get_comm_pair()
msg = {"op": "ping"}
await com.write(msg)
Expand All @@ -80,7 +62,7 @@ async def test_ping_pong():


@gen_test()
async def test_comm_objs():
async def test_comm_objs(ucx_loop):
comm, serv_comm = await get_comm_pair()

scheme, loc = parse_address(comm.peer_address)
Expand All @@ -93,7 +75,7 @@ async def test_comm_objs():


@gen_test()
async def test_ucx_specific():
async def test_ucx_specific(ucx_loop):
"""
Test concrete UCX API.
"""
Expand Down Expand Up @@ -147,7 +129,7 @@ async def client_communicate(key, delay=0):


@gen_test()
async def test_ping_pong_data():
async def test_ping_pong_data(ucx_loop):
np = pytest.importorskip("numpy")

data = np.ones((10, 10))
Expand All @@ -170,7 +152,7 @@ async def test_ping_pong_data():


@gen_test()
async def test_ucx_deserialize():
async def test_ucx_deserialize(ucx_loop):
# Note we see this error on some systems with this test:
# `socket.gaierror: [Errno -5] No address associated with hostname`
# This may be due to a system configuration issue.
Expand All @@ -196,7 +178,7 @@ async def test_ucx_deserialize():
],
)
@gen_test()
async def test_ping_pong_cudf(g):
async def test_ping_pong_cudf(ucx_loop, g):
# if this test appears after cupy an import error arises
# *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11'
# not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12)
Expand All @@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g):

@pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)])
@gen_test()
async def test_ping_pong_cupy(shape):
async def test_ping_pong_cupy(ucx_loop, shape):
cupy = pytest.importorskip("cupy")
com, serv_com = await get_comm_pair()

Expand All @@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape):
@pytest.mark.slow
@pytest.mark.parametrize("n", [int(1e9), int(2.5e9)])
@gen_test()
async def test_large_cupy(n, cleanup):
async def test_large_cupy(ucx_loop, n, cleanup):
cupy = pytest.importorskip("cupy")
com, serv_com = await get_comm_pair()

Expand All @@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup):


@gen_test()
async def test_ping_pong_numba():
async def test_ping_pong_numba(ucx_loop):
np = pytest.importorskip("numpy")
numba = pytest.importorskip("numba")
import numba.cuda
Expand All @@ -276,7 +258,7 @@ async def test_ping_pong_numba():

@pytest.mark.parametrize("processes", [True, False])
@gen_test()
async def test_ucx_localcluster(processes, cleanup):
async def test_ucx_localcluster(ucx_loop, processes, cleanup):
async with LocalCluster(
protocol="ucx",
host=HOST,
Expand All @@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup):

@pytest.mark.slow
@gen_test(timeout=60)
async def test_stress():
async def test_stress(
ucx_loop,
):
da = pytest.importorskip("dask.array")

chunksize = "10 MB"
Expand All @@ -322,15 +306,19 @@ async def test_stress():


@gen_test()
async def test_simple():
async def test_simple(
ucx_loop,
):
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
assert cluster.scheduler_address.startswith("ucx://")
assert await client.submit(lambda x: x + 1, 10) == 11


@gen_test()
async def test_cuda_context():
async def test_cuda_context(
ucx_loop,
):
with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
async with LocalCluster(
protocol="ucx", n_workers=1, asynchronous=True
Expand All @@ -344,7 +332,9 @@ async def test_cuda_context():


@gen_test()
async def test_transpose():
async def test_transpose(
ucx_loop,
):
da = pytest.importorskip("dask.array")

async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
Expand All @@ -358,7 +348,7 @@ async def test_transpose():

@pytest.mark.parametrize("port", [0, 1234])
@gen_test()
async def test_ucx_protocol(cleanup, port):
async def test_ucx_protocol(ucx_loop, cleanup, port):
async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s:
assert s.address.startswith("ucx://")

Expand All @@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port):
not hasattr(ucp.exceptions, "UCXUnreachable"),
reason="Requires UCX-Py support for UCXUnreachable exception",
)
def test_ucx_unreachable():
if ucp.get_ucx_version() > (1, 12, 0):
with pytest.raises(OSError, match="Timed out trying to connect to"):
Client("ucx://255.255.255.255:12345", timeout=1)
else:
with pytest.raises(ucp.exceptions.UCXError, match="Destination is unreachable"):
Client("ucx://255.255.255.255:12345", timeout=1)
@gen_test()
async def test_ucx_unreachable(
ucx_loop,
):
with pytest.raises(OSError, match="Timed out trying to connect to"):
await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True)
Loading

0 comments on commit f84e8f7

Please sign in to comment.