diff --git a/.github/workflows/test-report.yaml b/.github/workflows/test-report.yaml index 008ef404e46..e5d925cea33 100644 --- a/.github/workflows/test-report.yaml +++ b/.github/workflows/test-report.yaml @@ -8,6 +8,8 @@ on: jobs: test-report: name: Test Report + # Do not run the report job on forks + if: github.repository == 'dask/distributed' || github.event_name == 'workflow_dispatch' runs-on: ubuntu-latest env: GITHUB_TOKEN: ${{ github.token }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d1689ae23f..329103e7d02 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,7 +39,6 @@ repos: - types-docutils - types-requests - types-paramiko - - types-pkg_resources - types-PyYAML - types-setuptools - types-psutil diff --git a/continuous_integration/recipes/distributed/meta.yaml b/continuous_integration/recipes/distributed/meta.yaml index 4682d77cc5f..0353bdf4c79 100644 --- a/continuous_integration/recipes/distributed/meta.yaml +++ b/continuous_integration/recipes/distributed/meta.yaml @@ -73,7 +73,6 @@ outputs: - toolz >=0.8.2 - tornado >=6.0.3 - zict >=0.1.3 - - setuptools <60.0.0 run_constrained: - distributed-impl >={{ version }} *{{ build_ext }} - openssl !=1.1.1e diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 4ca4b82601d..e61e90f30be 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -1,3 +1,9 @@ +"""Implementation of the Active Memory Manager. This is a scheduler extension which +sends drop/replicate suggestions to the worker. + +See also :mod:`distributed.worker_memory` and :mod:`distributed.spill`, which implement +spill/pause/terminate mechanics on the Worker side. +""" from __future__ import annotations import logging diff --git a/distributed/cli/tests/test_dask_scheduler.py b/distributed/cli/tests/test_dask_scheduler.py index f9e15607386..d73905401ff 100644 --- a/distributed/cli/tests/test_dask_scheduler.py +++ b/distributed/cli/tests/test_dask_scheduler.py @@ -63,13 +63,13 @@ def test_no_dashboard(loop): def test_dashboard(loop): pytest.importorskip("bokeh") - with popen(["dask-scheduler"]) as proc: - for line in proc.stderr: + with popen(["dask-scheduler"], flush_output=False) as proc: + for line in proc.stdout: if b"dashboard at" in line: dashboard_port = int(line.decode().split(":")[-1].strip()) break else: - raise Exception("dashboard not found") + assert False # pragma: nocover with Client(f"127.0.0.1:{Scheduler.default_port}", loop=loop): pass @@ -217,12 +217,17 @@ def test_scheduler_port_zero(loop): def test_dashboard_port_zero(loop): pytest.importorskip("bokeh") - with popen(["dask-scheduler", "--dashboard-address", ":0"]) as proc: - for line in proc.stderr: + with popen( + ["dask-scheduler", "--dashboard-address", ":0"], + flush_output=False, + ) as proc: + for line in proc.stdout: if b"dashboard at" in line: dashboard_port = int(line.decode().split(":")[-1].strip()) assert dashboard_port != 0 break + else: + assert False # pragma: nocover PRELOAD_TEXT = """ diff --git a/distributed/cli/tests/test_dask_spec.py b/distributed/cli/tests/test_dask_spec.py index 45f88c894b3..06a9de702ac 100644 --- a/distributed/cli/tests/test_dask_spec.py +++ b/distributed/cli/tests/test_dask_spec.py @@ -80,13 +80,17 @@ def test_errors(): '{"foo": "bar"}', "--spec-file", "foo.yaml", - ] + ], + flush_output=False, ) as proc: line = proc.stdout.readline().decode() assert "exactly one" in line assert "--spec" in line and "--spec-file" in line - with 
popen([sys.executable, "-m", "distributed.cli.dask_spec"]) as proc: + with popen( + [sys.executable, "-m", "distributed.cli.dask_spec"], + flush_output=False, + ) as proc: line = proc.stdout.readline().decode() assert "exactly one" in line assert "--spec" in line and "--spec-file" in line diff --git a/distributed/cli/tests/test_dask_ssh.py b/distributed/cli/tests/test_dask_ssh.py index 055eb754a81..826946d77f8 100644 --- a/distributed/cli/tests/test_dask_ssh.py +++ b/distributed/cli/tests/test_dask_ssh.py @@ -18,23 +18,27 @@ def test_version_option(): assert result.exit_code == 0 +@pytest.mark.slow def test_ssh_cli_nprocs_renamed_to_nworkers(loop): - n_workers = 2 with popen( - ["dask-ssh", f"--nprocs={n_workers}", "--nohost", "localhost"] - ) as cluster: + ["dask-ssh", "--nprocs=2", "--nohost", "localhost"], + flush_output=False, + ) as proc: with Client("tcp://127.0.0.1:8786", timeout="15 seconds", loop=loop) as c: - c.wait_for_workers(n_workers, timeout="15 seconds") + c.wait_for_workers(2, timeout="15 seconds") # This interrupt is necessary for the cluster to place output into the stdout # and stderr pipes - cluster.send_signal(2) - _, stderr = cluster.communicate() - - assert any(b"renamed to --nworkers" in l for l in stderr.splitlines()) + proc.send_signal(2) + assert any( + b"renamed to --nworkers" in proc.stdout.readline() for _ in range(15) + ) def test_ssh_cli_nworkers_with_nprocs_is_an_error(): - with popen(["dask-ssh", "localhost", "--nprocs=2", "--nworkers=2"]) as c: + with popen( + ["dask-ssh", "localhost", "--nprocs=2", "--nworkers=2"], + flush_output=False, + ) as proc: assert any( - b"Both --nprocs and --nworkers" in c.stderr.readline() for i in range(15) + b"Both --nprocs and --nworkers" in proc.stdout.readline() for _ in range(15) ) diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py index 04e31847609..e38b4d70be0 100644 --- a/distributed/cli/tests/test_dask_worker.py +++ b/distributed/cli/tests/test_dask_worker.py @@ -1,15 +1,10 @@ import asyncio - -import pytest -from click.testing import CliRunner - -pytest.importorskip("requests") - import os from multiprocessing import cpu_count from time import sleep -import requests +import pytest +from click.testing import CliRunner from dask.utils import tmpfile @@ -18,127 +13,99 @@ from distributed.compatibility import LINUX, to_thread from distributed.deploy.utils import nprocesses_nthreads from distributed.metrics import time -from distributed.utils import parse_ports, sync from distributed.utils_test import gen_cluster, popen, requires_ipv6 -def test_nanny_worker_ports(loop): - with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:9359", - "--host", - "127.0.0.1", - "--worker-port", - "9684", - "--nanny-port", - "5273", - "--no-dashboard", - ] - ): - with Client("127.0.0.1:9359", loop=loop) as c: - start = time() - while True: - d = sync(c.loop, c.scheduler.identity) - if d["workers"]: - break - else: - assert time() - start < 60 - sleep(0.1) - assert ( - d["workers"]["tcp://127.0.0.1:9684"]["nanny"] - == "tcp://127.0.0.1:5273" - ) +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_nanny_worker_ports(c, s): + with popen( + [ + "dask-worker", + s.address, + "--host", + "127.0.0.1", + "--worker-port", + "9684", + "--nanny-port", + "5273", + "--no-dashboard", + ] + ): + await c.wait_for_workers(1) + d = await c.scheduler.identity() + assert d["workers"]["tcp://127.0.0.1:9684"]["nanny"] == 
"tcp://127.0.0.1:5273" @pytest.mark.slow -def test_nanny_worker_port_range(loop): - with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]) as sched: - n_workers = 3 - worker_port = "9684:9686" - nanny_port = "9688:9690" - with popen( - [ - "dask-worker", - "127.0.0.1:9359", - "--nworkers", - f"{n_workers}", - "--host", - "127.0.0.1", - "--worker-port", - worker_port, - "--nanny-port", - nanny_port, - "--no-dashboard", - ] - ): - with Client("127.0.0.1:9359", loop=loop) as c: - start = time() - while len(c.scheduler_info()["workers"]) < n_workers: - sleep(0.1) - assert time() - start < 60 - - def get_port(dask_worker): - return dask_worker.port - - expected_worker_ports = set(parse_ports(worker_port)) - worker_ports = c.run(get_port) - assert set(worker_ports.values()) == expected_worker_ports - - expected_nanny_ports = set(parse_ports(nanny_port)) - nanny_ports = c.run(get_port, nanny=True) - assert set(nanny_ports.values()) == expected_nanny_ports - - -def test_nanny_worker_port_range_too_many_workers_raises(loop): - with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:9359", - "--nworkers", - "3", - "--host", - "127.0.0.1", - "--worker-port", - "9684:9685", - "--nanny-port", - "9686:9687", - "--no-dashboard", - ] - ) as worker: - assert any( - b"Could not start" in worker.stderr.readline() for _ in range(100) - ) - - -def test_memory_limit(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:8786", - "--memory-limit", - "2e3MB", - "--no-dashboard", - ] - ): - with Client("127.0.0.1:8786", loop=loop) as c: - while not c.nthreads(): - sleep(0.1) - info = c.scheduler_info() - [d] = info["workers"].values() - assert isinstance(d["memory_limit"], int) - assert d["memory_limit"] == 2e9 +@gen_cluster(client=True, nthreads=[]) +async def test_nanny_worker_port_range(c, s): + with popen( + [ + "dask-worker", + s.address, + "--nworkers", + "3", + "--host", + "127.0.0.1", + "--worker-port", + "9684:9686", + "--nanny-port", + "9688:9690", + "--no-dashboard", + ] + ): + await c.wait_for_workers(3) + worker_ports = await c.run(lambda dask_worker: dask_worker.port) + assert set(worker_ports.values()) == {9684, 9685, 9686} + nanny_ports = await c.run(lambda dask_worker: dask_worker.port, nanny=True) + assert set(nanny_ports.values()) == {9688, 9689, 9690} -def test_no_nanny(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - ["dask-worker", "127.0.0.1:8786", "--no-nanny", "--no-dashboard"] - ) as worker: - assert any(b"Registered" in worker.stderr.readline() for i in range(15)) +@gen_cluster(nthreads=[]) +async def test_nanny_worker_port_range_too_many_workers_raises(s): + with popen( + [ + "dask-worker", + s.address, + "--nworkers", + "3", + "--host", + "127.0.0.1", + "--worker-port", + "9684:9685", + "--nanny-port", + "9686:9687", + "--no-dashboard", + ], + flush_output=False, + ) as worker: + assert any(b"Could not start" in worker.stdout.readline() for _ in range(100)) + + +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_memory_limit(c, s): + with popen( + [ + "dask-worker", + s.address, + "--memory-limit", + "2e3MB", + "--no-dashboard", + ] + ): + await c.wait_for_workers(1) + info = c.scheduler_info() + (d,) = info["workers"].values() + assert isinstance(d["memory_limit"], int) + assert d["memory_limit"] == 2e9 + + +@gen_cluster(client=True, nthreads=[]) +async def test_no_nanny(c, s): + with popen(["dask-worker", 
s.address, "--no-nanny", "--no-dashboard"]): + await c.wait_for_workers(1) @pytest.mark.slow @@ -161,7 +128,7 @@ async def test_no_reconnect(c, s, nanny): comm.abort() # worker terminates as soon as the connection is aborted - await to_thread(worker.communicate, timeout=5) + await to_thread(worker.wait, timeout=5) assert worker.returncode == 0 @@ -179,7 +146,7 @@ async def test_reconnect(c, s, nanny): ] ) as worker: # roundtrip works - await c.submit(lambda x: x + 1, 10) == 11 + assert await c.submit(lambda x: x + 1, 10) == 11 (comm,) = s.stream_comms.values() comm.abort() @@ -189,53 +156,49 @@ async def test_reconnect(c, s, nanny): # closing the scheduler cleanly does terminate the worker await s.close() - await to_thread(worker.communicate, timeout=5) + await to_thread(worker.wait, timeout=5) assert worker.returncode == 0 -def test_resources(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "tcp://127.0.0.1:8786", - "--no-dashboard", - "--resources", - "A=1 B=2,C=3", - ] - ): - with Client("127.0.0.1:8786", loop=loop) as c: - while not c.scheduler_info()["workers"]: - sleep(0.1) - info = c.scheduler_info() - worker = list(info["workers"].values())[0] - assert worker["resources"] == {"A": 1, "B": 2, "C": 3} +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_resources(c, s): + with popen( + [ + "dask-worker", + s.address, + "--no-dashboard", + "--resources", + "A=1 B=2,C=3", + ] + ): + await c.wait_for_workers(1) + info = c.scheduler_info() + (d,) = info["workers"].values() + assert d["resources"] == {"A": 1, "B": 2, "C": 3} +@pytest.mark.slow @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) -def test_local_directory(loop, nanny): - with tmpfile() as fn: - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:8786", - nanny, - "--no-dashboard", - "--local-directory", - fn, - ] - ): - with Client("127.0.0.1:8786", loop=loop, timeout=10) as c: - start = time() - while not c.scheduler_info()["workers"]: - sleep(0.1) - assert time() < start + 8 - info = c.scheduler_info() - worker = list(info["workers"].values())[0] - assert worker["local_directory"].startswith(fn) +@gen_cluster(client=True, nthreads=[]) +async def test_local_directory(c, s, nanny, tmpdir): + with popen( + [ + "dask-worker", + s.address, + nanny, + "--no-dashboard", + "--local-directory", + str(tmpdir), + ] + ): + await c.wait_for_workers(1) + info = c.scheduler_info() + (d,) = info["workers"].values() + assert d["local_directory"].startswith(str(tmpdir)) +@pytest.mark.slow @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) def test_scheduler_file(loop, nanny): with tmpfile() as fn: @@ -250,6 +213,7 @@ def test_scheduler_file(loop, nanny): assert time() < start + 10 +@pytest.mark.slow def test_scheduler_address_env(loop, monkeypatch): monkeypatch.setenv("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:8786") with popen(["dask-scheduler", "--no-dashboard"]): @@ -261,166 +225,165 @@ def test_scheduler_address_env(loop, monkeypatch): assert time() < start + 10 -def test_nworkers_requires_nanny(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - ["dask-worker", "127.0.0.1:8786", "--nworkers=2", "--no-nanny"] - ) as worker: - assert any( - b"Failed to launch worker" in worker.stderr.readline() - for i in range(15) - ) +@gen_cluster(nthreads=[]) +async def test_nworkers_requires_nanny(s): + with popen( + ["dask-worker", s.address, "--nworkers=2", "--no-nanny"], + flush_output=False, 
+ ) as worker: + assert any( + b"Failed to launch worker" in worker.stdout.readline() for _ in range(15) + ) -def test_nworkers_negative(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen(["dask-worker", "127.0.0.1:8786", "--nworkers=-1"]): - with Client("tcp://127.0.0.1:8786", loop=loop) as c: - c.wait_for_workers(cpu_count(), timeout="10 seconds") +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_nworkers_negative(c, s): + with popen(["dask-worker", s.address, "--nworkers=-1"]): + await c.wait_for_workers(cpu_count()) -def test_nworkers_auto(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen(["dask-worker", "127.0.0.1:8786", "--nworkers=auto"]): - with Client("tcp://127.0.0.1:8786", loop=loop) as c: - procs, _ = nprocesses_nthreads() - c.wait_for_workers(procs, timeout="10 seconds") +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_nworkers_auto(c, s): + with popen(["dask-worker", s.address, "--nworkers=auto"]): + procs, _ = nprocesses_nthreads() + await c.wait_for_workers(procs) -def test_nworkers_expands_name(loop): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen(["dask-worker", "127.0.0.1:8786", "--nworkers", "2", "--name", "0"]): - with popen(["dask-worker", "127.0.0.1:8786", "--nworkers", "2"]): - with Client("tcp://127.0.0.1:8786", loop=loop) as c: - start = time() - while len(c.scheduler_info()["workers"]) < 4: - sleep(0.2) - assert time() < start + 30 +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_nworkers_expands_name(c, s): + with popen(["dask-worker", s.address, "--nworkers", "2", "--name", "0"]): + with popen(["dask-worker", s.address, "--nworkers", "2"]): + await c.wait_for_workers(4) + info = c.scheduler_info() - info = c.scheduler_info() - names = [d["name"] for d in info["workers"].values()] - foos = [n for n in names if n.startswith("0-")] - assert len(foos) == 2 - assert len(set(names)) == 4 + names = [d["name"] for d in info["workers"].values()] + assert len(names) == len(set(names)) == 4 + zeros = [n for n in names if n.startswith("0-")] + assert len(zeros) == 2 -def test_worker_cli_nprocs_renamed_to_nworkers(loop): - n_workers = 2 - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - ["dask-worker", "127.0.0.1:8786", f"--nprocs={n_workers}"] - ) as worker: - assert any( - b"renamed to --nworkers" in worker.stderr.readline() for i in range(15) - ) - with Client("tcp://127.0.0.1:8786", loop=loop) as c: - c.wait_for_workers(n_workers, timeout="30 seconds") +@pytest.mark.slow +@gen_cluster(client=True, nthreads=[]) +async def test_worker_cli_nprocs_renamed_to_nworkers(c, s): + with popen( + ["dask-worker", s.address, "--nprocs=2"], + flush_output=False, + ) as worker: + await c.wait_for_workers(2) + assert any( + b"renamed to --nworkers" in worker.stdout.readline() for _ in range(15) + ) -def test_worker_cli_nworkers_with_nprocs_is_an_error(): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - ["dask-worker", "127.0.0.1:8786", "--nprocs=2", "--nworkers=2"] - ) as worker: - assert any( - b"Both --nprocs and --nworkers" in worker.stderr.readline() - for i in range(15) - ) +@gen_cluster(nthreads=[]) +async def test_worker_cli_nworkers_with_nprocs_is_an_error(s): + with popen( + ["dask-worker", s.address, "--nprocs=2", "--nworkers=2"], + flush_output=False, + ) as worker: + assert any( + b"Both --nprocs and --nworkers" in worker.stdout.readline() + for _ in range(15) + ) +@pytest.mark.slow 
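# A minimal sketch of the output-scanning pattern used by the converted tests above,
# assuming (not stated in the diff) that the ``popen`` test helper, when called with
# ``flush_output=False``, pipes the child's stderr into stdout and leaves that pipe
# open for the test to read -- which is why these tests now scan ``proc.stdout`` for
# messages that used to be read from ``proc.stderr``:
#
#     with popen(["dask-worker", s.address, "--nprocs=2"], flush_output=False) as proc:
#         assert any(b"renamed to --nworkers" in proc.stdout.readline() for _ in range(15))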
@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @pytest.mark.parametrize( "listen_address", ["tcp://0.0.0.0:39837", "tcp://127.0.0.2:39837"] ) -def test_contact_listen_address(loop, nanny, listen_address): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:8786", - nanny, - "--no-dashboard", - "--contact-address", - "tcp://127.0.0.2:39837", - "--listen-address", - listen_address, - ] - ): - with Client("127.0.0.1:8786") as client: - while not client.nthreads(): - sleep(0.1) - info = client.scheduler_info() - assert "tcp://127.0.0.2:39837" in info["workers"] +@gen_cluster(client=True, nthreads=[]) +async def test_contact_listen_address(c, s, nanny, listen_address): + with popen( + [ + "dask-worker", + s.address, + nanny, + "--no-dashboard", + "--contact-address", + "tcp://127.0.0.2:39837", + "--listen-address", + listen_address, + ] + ): + await c.wait_for_workers(1) + info = c.scheduler_info() + assert info["workers"].keys() == {"tcp://127.0.0.2:39837"} - # roundtrip works - assert client.submit(lambda x: x + 1, 10).result() == 11 + # roundtrip works + assert await c.submit(lambda x: x + 1, 10) == 11 - def func(dask_worker): - return dask_worker.listener.listen_address + def func(dask_worker): + return dask_worker.listener.listen_address - assert client.run(func) == {"tcp://127.0.0.2:39837": listen_address} + assert await c.run(func) == {"tcp://127.0.0.2:39837": listen_address} +@pytest.mark.slow @requires_ipv6 @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @pytest.mark.parametrize("listen_address", ["tcp://:39838", "tcp://[::1]:39838"]) -def test_listen_address_ipv6(loop, nanny, listen_address): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - [ - "dask-worker", - "127.0.0.1:8786", - nanny, - "--no-dashboard", - "--listen-address", - listen_address, - ] - ): - # IPv4 used by default for name of global listener; IPv6 used by default when - # listening only on IPv6. - bind_all = "[::1]" not in listen_address - expected_ip = "127.0.0.1" if bind_all else "[::1]" - expected_name = f"tcp://{expected_ip}:39838" - expected_listen = "tcp://0.0.0.0:39838" if bind_all else listen_address - with Client("127.0.0.1:8786") as client: - while not client.nthreads(): - sleep(0.1) - info = client.scheduler_info() - assert expected_name in info["workers"] - assert client.submit(lambda x: x + 1, 10).result() == 11 +@gen_cluster(client=True, nthreads=[]) +async def test_listen_address_ipv6(c, s, nanny, listen_address): + with popen( + [ + "dask-worker", + s.address, + nanny, + "--no-dashboard", + "--listen-address", + listen_address, + ] + ): + # IPv4 used by default for name of global listener; IPv6 used by default when + # listening only on IPv6. 
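# For the two parametrized values, the expectations below work out to:
#   listen_address="tcp://:39838"      -> expected_name="tcp://127.0.0.1:39838", expected_listen="tcp://0.0.0.0:39838"
#   listen_address="tcp://[::1]:39838" -> expected_name="tcp://[::1]:39838",     expected_listen=listen_address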
+ bind_all = "[::1]" not in listen_address + expected_ip = "127.0.0.1" if bind_all else "[::1]" + expected_name = f"tcp://{expected_ip}:39838" + expected_listen = "tcp://0.0.0.0:39838" if bind_all else listen_address - def func(dask_worker): - return dask_worker.listener.listen_address + await c.wait_for_workers(1) + info = c.scheduler_info() + assert info["workers"].keys() == {expected_name} + assert await c.submit(lambda x: x + 1, 10) == 11 + + def func(dask_worker): + return dask_worker.listener.listen_address - assert client.run(func) == {expected_name: expected_listen} + assert await c.run(func) == {expected_name: expected_listen} +@pytest.mark.slow @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @pytest.mark.parametrize("host", ["127.0.0.2", "0.0.0.0"]) -def test_respect_host_listen_address(loop, nanny, host): - with popen(["dask-scheduler", "--no-dashboard"]): - with popen( - ["dask-worker", "127.0.0.1:8786", nanny, "--no-dashboard", "--host", host] - ) as worker: - with Client("127.0.0.1:8786") as client: - while not client.nthreads(): - sleep(0.1) - client.scheduler_info() +@gen_cluster(client=True, nthreads=[]) +async def test_respect_host_listen_address(c, s, nanny, host): + with popen(["dask-worker", s.address, nanny, "--no-dashboard", "--host", host]): + await c.wait_for_workers(1) - # roundtrip works - assert client.submit(lambda x: x + 1, 10).result() == 11 + # roundtrip works + assert await c.submit(lambda x: x + 1, 10) == 11 - def func(dask_worker): - return dask_worker.listener.listen_address + def func(dask_worker): + return dask_worker.listener.listen_address - listen_addresses = client.run(func) - assert all(host in v for v in listen_addresses.values()) + listen_addresses = await c.run(func) + assert all(host in v for v in listen_addresses.values()) -def test_dashboard_non_standard_ports(loop): +@pytest.mark.slow +@gen_cluster( + client=True, nthreads=[], scheduler_kwargs={"dashboard_address": "localhost:8787"} +) +async def test_dashboard_non_standard_ports(c, s): pytest.importorskip("bokeh") + requests = pytest.importorskip("requests") + try: import jupyter_server_proxy # noqa: F401 @@ -428,33 +391,27 @@ def test_dashboard_non_standard_ports(loop): except ImportError: proxy_exists = False - with popen(["dask-scheduler", "--port", "3449"]): - with popen( - [ - "dask-worker", - "tcp://127.0.0.1:3449", - "--dashboard-address", - ":4833", - "--host", - "127.0.0.1", - ] - ): - with Client("127.0.0.1:3449", loop=loop) as c: - c.wait_for_workers(1) - pass - - response = requests.get("http://127.0.0.1:4833/status") - assert response.ok - redirect_resp = requests.get("http://127.0.0.1:4833/main") - redirect_resp.ok - # TEST PROXYING WORKS - if proxy_exists: - url = "http://127.0.0.1:8787/proxy/4833/127.0.0.1/status" - response = requests.get(url) - assert response.ok - - with pytest.raises(Exception): - requests.get("http://localhost:4833/status/") + with popen( + [ + "dask-worker", + s.address, + "--dashboard-address", + ":4833", + "--host", + "127.0.0.1", + ] + ): + await c.wait_for_workers(1) + + response = requests.get("http://127.0.0.1:4833/status") + response.raise_for_status() + # TEST PROXYING WORKS + if proxy_exists: + response = requests.get("http://127.0.0.1:8787/proxy/4833/127.0.0.1/status") + response.raise_for_status() + + with pytest.raises(requests.ConnectionError): + requests.get("http://localhost:4833/status/") def test_version_option(): @@ -493,6 +450,7 @@ def 
test_bokeh_deprecation(): pass +@pytest.mark.slow @gen_cluster(nthreads=[]) async def test_integer_names(s): with popen(["dask-worker", s.address, "--name", "123"]): @@ -502,6 +460,7 @@ async def test_integer_names(s): assert ws.name == 123 +@pytest.mark.slow @pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"]) @gen_cluster(client=True, nthreads=[]) async def test_worker_class(c, s, tmp_path, nanny): @@ -543,6 +502,7 @@ def worker_type(dask_worker): assert all(name == "MyWorker" for name in worker_types.values()) +@pytest.mark.slow @gen_cluster(nthreads=[], client=True) async def test_preload_config(c, s): # Ensure dask-worker pulls the preload from the Dask config if @@ -553,13 +513,7 @@ def dask_setup(worker): """ env = os.environ.copy() env["DASK_DISTRIBUTED__WORKER__PRELOAD"] = preload_text - with popen( - [ - "dask-worker", - s.address, - ], - env=env, - ): + with popen(["dask-worker", s.address], env=env): await c.wait_for_workers(1) [foo] = (await c.run(lambda dask_worker: dask_worker.foo)).values() assert foo == "setup" diff --git a/distributed/comm/registry.py b/distributed/comm/registry.py index 00b10336a70..47ba730a7d9 100644 --- a/distributed/comm/registry.py +++ b/distributed/comm/registry.py @@ -1,6 +1,29 @@ from __future__ import annotations +import importlib.metadata +import sys from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Protocol + + +class _EntryPoints(Protocol): + def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]: + ... + + +if sys.version_info >= (3, 10): + # py3.10 importlib.metadata type annotations are not in mypy yet + # https://github.com/python/typeshed/pull/7331 + _entry_points: _EntryPoints = importlib.metadata.entry_points # type: ignore[assignment] +else: + + def _entry_points( + *, group: str, name: str + ) -> Iterable[importlib.metadata.EntryPoint]: + for ep in importlib.metadata.entry_points().get(group, []): + if ep.name == name: + yield ep class Backend(ABC): @@ -59,40 +82,24 @@ def get_local_address_for(self, loc): backends: dict[str, Backend] = {} -def get_backend(scheme: str, require: bool = True) -> Backend: +def get_backend(scheme: str) -> Backend: """ Get the Backend instance for the given *scheme*. It looks for matching scheme in dask's internal cache, and falls-back to package metadata for the group name ``distributed.comm.backends`` - - Parameters - ---------- - - require : bool - Verify that the backends requirements are properly installed. See - https://setuptools.readthedocs.io/en/latest/pkg_resources.html for more - information. 
""" backend = backends.get(scheme) - if backend is None: - import pkg_resources - - backend = None - for backend_class_ep in pkg_resources.iter_entry_points( - "distributed.comm.backends", scheme - ): - # resolve and require are equivalent to load - backend_factory = backend_class_ep.resolve() - if require: - backend_class_ep.require() - backend = backend_factory() - - if backend is None: - raise ValueError( - "unknown address scheme %r (known schemes: %s)" - % (scheme, sorted(backends)) - ) - else: - backends[scheme] = backend - return backend + if backend is not None: + return backend + + for backend_class_ep in _entry_points( + name=scheme, group="distributed.comm.backends" + ): + backend = backend_class_ep.load()() + backends[scheme] = backend + return backend + + raise ValueError( + f"unknown address scheme {scheme!r} (known schemes: {sorted(backends)})" + ) diff --git a/distributed/comm/tests/test_comms.py b/distributed/comm/tests/test_comms.py index 46b5ab8d03b..f9ecf5e0726 100644 --- a/distributed/comm/tests/test_comms.py +++ b/distributed/comm/tests/test_comms.py @@ -2,18 +2,15 @@ import os import sys import threading -import types import warnings from functools import partial -import pkg_resources import pytest from tornado import ioloop from tornado.concurrent import Future import dask -import distributed from distributed.comm import ( CommClosedError, asyncio_tcp, @@ -30,7 +27,7 @@ from distributed.comm.registry import backends, get_backend from distributed.metrics import time from distributed.protocol import Serialized, deserialize, serialize, to_serialize -from distributed.utils import get_ip, get_ipv6 +from distributed.utils import get_ip, get_ipv6, mp_context from distributed.utils_test import ( get_cert, get_client_ssl_context, @@ -1313,30 +1310,18 @@ async def test_inproc_adresses(): await check_addresses(a, b) -def test_register_backend_entrypoint(): - # Code adapted from pandas backend entry point testing - # https://github.com/pandas-dev/pandas/blob/2470690b9f0826a8feb426927694fa3500c3e8d2/pandas/tests/plotting/test_backend.py#L50-L76 +def _get_backend_on_path(path): + sys.path.append(os.fsdecode(path)) + return get_backend("udp") - dist = pkg_resources.get_distribution("distributed") - if dist.module_path not in distributed.__file__: - # We are running from a non-installed distributed, and this test is invalid - pytest.skip("Testing a non-installed distributed") - mod = types.ModuleType("dask_udp") - mod.UDPBackend = lambda: 1 - sys.modules[mod.__name__] = mod - - entry_point_name = "distributed.comm.backends" - backends_entry_map = pkg_resources.get_entry_map("distributed") - if entry_point_name not in backends_entry_map: - backends_entry_map[entry_point_name] = dict() - backends_entry_map[entry_point_name]["udp"] = pkg_resources.EntryPoint( - "udp", mod.__name__, attrs=["UDPBackend"], dist=dist +def test_register_backend_entrypoint(tmp_path): + (tmp_path / "dask_udp.py").write_bytes(b"def udp_backend():\n return 1\n") + dist_info = tmp_path / "dask_udp-0.0.0.dist-info" + dist_info.mkdir() + (dist_info / "entry_points.txt").write_bytes( + b"[distributed.comm.backends]\nudp = dask_udp:udp_backend\n" ) - - # The require is disabled here since particularly unit tests may install - # dirty or dev versions which are conflicting with backend entrypoints if - # they are demanding for exact, stable versions. 
This should not fail the - # test - result = get_backend("udp", require=False) - assert result == 1 + with mp_context.Pool(1) as pool: + assert pool.apply(_get_backend_on_path, args=(tmp_path,)) == 1 + pool.join() diff --git a/distributed/core.py b/distributed/core.py index df57a337d59..6d043c62b1d 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -263,7 +263,10 @@ def set_thread_ident(): @property def status(self): - return self._status + try: + return self._status + except AttributeError: + return Status.undefined @status.setter def status(self, new_status): @@ -398,9 +401,7 @@ def port(self): def identity(self) -> dict[str, str]: return {"type": type(self).__name__, "id": self.id} - def _to_dict( - self, comm: Comm | None = None, *, exclude: Container[str] = () - ) -> dict: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 1dc529b19f8..bab3cb2b1e5 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -49,7 +49,7 @@ import dask from dask import config -from dask.utils import format_bytes, format_time, key_split, parse_timedelta +from dask.utils import format_bytes, format_time, funcname, key_split, parse_timedelta from distributed.dashboard.components import add_periodic_callback from distributed.dashboard.components.shared import ( @@ -914,7 +914,7 @@ def __init__(self, scheduler, **kwargs): tools = "reset, xpan, xwheel_zoom" self.bandwidth = figure( - title="Workers Network Bandwidth", + title="Worker Network Bandwidth (average)", x_axis_type="datetime", tools=tools, x_range=x_range, @@ -946,7 +946,7 @@ def __init__(self, scheduler, **kwargs): self.bandwidth.xgrid.visible = False self.cpu = figure( - title="Workers CPU", + title="Worker CPU Utilization (average)", x_axis_type="datetime", tools=tools, x_range=x_range, @@ -966,7 +966,7 @@ def __init__(self, scheduler, **kwargs): self.cpu.xgrid.visible = False self.memory = figure( - title="Workers Memory", + title="Worker Memory Use (average)", x_axis_type="datetime", tools=tools, x_range=x_range, @@ -987,7 +987,7 @@ def __init__(self, scheduler, **kwargs): self.memory.xgrid.visible = False self.disk = figure( - title="Workers Disk", + title="Worker Disk Bandwidth (average)", x_axis_type="datetime", tools=tools, x_range=x_range, @@ -3499,6 +3499,7 @@ def individual_doc(cls, interval, scheduler, extra, doc, fig_attr="root", **kwar add_periodic_callback(doc, fig, interval) doc.add_root(getattr(fig, fig_attr)) doc.theme = BOKEH_THEME + doc.title = "Dask: " + funcname(cls) def individual_profile_doc(scheduler, extra, doc): diff --git a/distributed/dashboard/tests/test_scheduler_bokeh.py b/distributed/dashboard/tests/test_scheduler_bokeh.py index 2907482e078..5242aa1eb80 100644 --- a/distributed/dashboard/tests/test_scheduler_bokeh.py +++ b/distributed/dashboard/tests/test_scheduler_bokeh.py @@ -951,6 +951,12 @@ async def test_aggregate_action(c, s, a, b): assert ("transfer") in mbk.action_source.data["names"] assert ("compute") in mbk.action_source.data["names"] + [title_line] = [ + line for line in response.body.decode().split("\n") if "" in line + ] + assert "AggregateAction" in title_line + assert "Bokeh" not in title_line + @gen_cluster(client=True, scheduler_kwargs={"dashboard": True}) async def test_compute_per_key(c, s, a, b): diff --git 
a/distributed/deploy/local.py b/distributed/deploy/local.py index db74bf9201e..d3437671054 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -12,7 +12,8 @@ from distributed.nanny import Nanny from distributed.scheduler import Scheduler from distributed.security import Security -from distributed.worker import Worker, parse_memory_limit +from distributed.worker import Worker +from distributed.worker_memory import parse_memory_limit logger = logging.getLogger(__name__) diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 9fb5e5c3693..0da31e10199 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -127,7 +127,8 @@ class WorkerPlugin: ... exc_info=exc_info ... ) - >>> plugin = ErrorLogger() + >>> import logging + >>> plugin = ErrorLogger(logging) >>> client.register_worker_plugin(plugin) # doctest: +SKIP """ diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index d0fd06474e4..2c6c0d5b566 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -512,6 +512,11 @@ properties: description: >- Limit of number of bytes to be spilled on disk. + monitor-interval: + type: string + description: >- + Interval between checks for the spill, pause, and terminate thresholds + http: type: object description: Settings for Dask's embedded HTTP Server diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index e8a87348b40..27642579409 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -149,6 +149,10 @@ distributed: # Set to false for no maximum. max-spill: false + # Interval between checks for the spill, pause, and terminate thresholds. + # The target threshold is checked every time new data is inserted. + monitor-interval: 100ms + http: routes: - distributed.http.worker.prometheus diff --git a/distributed/nanny.py b/distributed/nanny.py index fd7a6882cda..e946655f501 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -13,11 +13,11 @@ from inspect import isawaitable from queue import Empty from time import sleep as sync_sleep -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, ClassVar import psutil from tornado import gen -from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.ioloop import IOLoop import dask from dask.system import CPU_COUNT @@ -49,7 +49,12 @@ parse_ports, silence_logging, ) -from distributed.worker import Worker, parse_memory_limit, run +from distributed.worker import Worker, run +from distributed.worker_memory import ( + DeprecatedMemoryManagerAttribute, + DeprecatedMemoryMonitor, + NannyMemoryManager, +) if TYPE_CHECKING: from distributed.diagnostics.plugin import NannyPlugin @@ -89,6 +94,7 @@ class Nanny(ServerNode): _instances: ClassVar[weakref.WeakSet[Nanny]] = weakref.WeakSet() process = None status = Status.undefined + memory_manager: NannyMemoryManager def __init__( self, @@ -103,7 +109,6 @@ def __init__( services=None, name=None, memory_limit="auto", - memory_terminate_fraction: float | Literal[False] | None = None, reconnect=True, validate=False, quiet=False, @@ -192,7 +197,8 @@ def __init__( config_environ = dask.config.get("distributed.nanny.environ", {}) if not isinstance(config_environ, dict): raise TypeError( - f"distributed.nanny.environ configuration must be of type dict. Instead got {type(config_environ)}" + "distributed.nanny.environ configuration must be of type dict. 
" + f"Instead got {type(config_environ)}" ) self.env = config_environ.copy() for k in self.env: @@ -213,19 +219,12 @@ def __init__( self.worker_kwargs = worker_kwargs self.contact_address = contact_address - self.memory_terminate_fraction = ( - memory_terminate_fraction - if memory_terminate_fraction is not None - else dask.config.get("distributed.worker.memory.terminate") - ) self.services = services self.name = name self.quiet = quiet self.auto_restart = True - self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) - if silence_logs: silence_logging(level=silence_logs) self.silence_logs = silence_logs @@ -250,10 +249,7 @@ def __init__( ) self.scheduler = self.rpc(self.scheduler_addr) - - if self.memory_limit: - pc = PeriodicCallback(self.memory_monitor, 100) - self.periodic_callbacks["memory"] = pc + self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit) if ( not host @@ -271,6 +267,11 @@ def __init__( Nanny._instances.add(self) self.status = Status.init + # Deprecated attributes; use Nanny.memory_manager.<name> instead + memory_limit = DeprecatedMemoryManagerAttribute() + memory_terminate_fraction = DeprecatedMemoryManagerAttribute() + memory_monitor = DeprecatedMemoryMonitor() + def __repr__(self): return "<Nanny: %s, threads: %d>" % (self.worker_address, self.nthreads) @@ -388,7 +389,7 @@ async def instantiate(self) -> Status: services=self.services, nanny=self.address, name=self.name, - memory_limit=self.memory_limit, + memory_limit=self.memory_manager.memory_limit, reconnect=self.reconnect, resources=self.resources, validate=self.validate, @@ -502,28 +503,6 @@ def _psutil_process(self): return self._psutil_process_obj - def memory_monitor(self): - """Track worker's memory. Restart if it goes above terminate fraction""" - if self.status != Status.running: - return - if self.process is None or self.process.process is None: - return None - process = self.process.process - - try: - proc = self._psutil_process - memory = proc.memory_info().rss - except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): - return - frac = memory / self.memory_limit - - if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: - logger.warning( - "Worker exceeded %d%% memory budget. 
Restarting", - 100 * self.memory_terminate_fraction, - ) - process.terminate() - def is_alive(self): return self.process is not None and self.process.is_alive() diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index b10bfcba183..0e0ae003b5f 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -2,10 +2,15 @@ import msgpack +import dask.config + +from distributed.protocol import pickle from distributed.protocol.compression import decompress, maybe_compress from distributed.protocol.serialize import ( + Pickled, Serialize, Serialized, + ToPickle, merge_and_deserialize, msgpack_decode_default, msgpack_encode_default, @@ -16,6 +21,15 @@ logger = logging.getLogger(__name__) +def ensure_memoryview(obj): + """Ensure `obj` is a memoryview of datatype bytes""" + ret = memoryview(obj) + if ret.nbytes: + return ret.cast("B") + else: + return ret + + def dumps( msg, serializers=None, on_error="message", context=None, frame_split_size=None ) -> list: @@ -45,31 +59,59 @@ def _inplace_compress_frames(header, frames): header["compression"] = tuple(compression) + def create_serialized_sub_frames(obj) -> list: + typ = type(obj) + if typ is Serialized: + sub_header, sub_frames = obj.header, obj.frames + else: + sub_header, sub_frames = serialize_and_split( + obj, + serializers=serializers, + on_error=on_error, + context=context, + size=frame_split_size, + ) + _inplace_compress_frames(sub_header, sub_frames) + sub_header["num-sub-frames"] = len(sub_frames) + sub_header = msgpack.dumps( + sub_header, default=msgpack_encode_default, use_bin_type=True + ) + return [sub_header] + sub_frames + + def create_pickled_sub_frames(obj) -> list: + typ = type(obj) + if typ is Pickled: + sub_header, sub_frames = obj.header, obj.frames + else: + sub_frames = [] + sub_header = { + "pickled-obj": pickle.dumps( + obj.data, + # In to support len() and slicing, we convert `PickleBuffer` + # objects to memoryviews of bytes. 
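# (Out-of-band pickling requires protocol 5: buffers that support it are handed to
# ``buffer_callback`` as ``pickle.PickleBuffer`` objects instead of being copied
# in-band; ``ensure_memoryview`` casts each one to a flat "B" memoryview so the
# resulting frame supports ``len()`` and slicing for compression and framing.)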
+ buffer_callback=lambda x: sub_frames.append( + ensure_memoryview(x) + ), + ) + } + _inplace_compress_frames(sub_header, sub_frames) + + sub_header["num-sub-frames"] = len(sub_frames) + sub_header = msgpack.dumps(sub_header) + return [sub_header] + sub_frames + frames = [None] def _encode_default(obj): typ = type(obj) if typ is Serialize or typ is Serialized: offset = len(frames) - if typ is Serialized: - sub_header, sub_frames = obj.header, obj.frames - else: - sub_header, sub_frames = serialize_and_split( - obj, - serializers=serializers, - on_error=on_error, - context=context, - size=frame_split_size, - ) - _inplace_compress_frames(sub_header, sub_frames) - sub_header["num-sub-frames"] = len(sub_frames) - frames.append( - msgpack.dumps( - sub_header, default=msgpack_encode_default, use_bin_type=True - ) - ) - frames.extend(sub_frames) + frames.extend(create_serialized_sub_frames(obj)) return {"__Serialized__": offset} + elif typ is ToPickle or typ is Pickled: + offset = len(frames) + frames.extend(create_pickled_sub_frames(obj)) + return {"__Pickled__": offset} else: return msgpack_encode_default(obj) @@ -84,6 +126,8 @@ def _encode_default(obj): def loads(frames, deserialize=True, deserializers=None): """Transform bytestream back into Python value""" + allow_pickle = dask.config.get("distributed.scheduler.pickle") + try: def _decode_default(obj): @@ -105,8 +149,20 @@ def _decode_default(obj): ) else: return Serialized(sub_header, sub_frames) - else: - return msgpack_decode_default(obj) + + offset = obj.get("__Pickled__", 0) + if offset > 0: + sub_header = msgpack.loads(frames[offset]) + offset += 1 + sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] + if allow_pickle: + return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) + else: + raise ValueError( + "Unpickle on the Scheduler isn't allowed, set `distributed.scheduler.pickle=true`" + ) + + return msgpack_decode_default(obj) return msgpack.loads( frames[0], object_hook=_decode_default, use_list=False, **msgpack_opts diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index c8bfc9dce7f..b4daf5bd658 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -522,8 +522,7 @@ def __hash__(self): class Serialized: - """ - An object that is already serialized into header and frames + """An object that is already serialized into header and frames Normal serialization operations pass these objects through. This is typically used within the scheduler which accepts messages that contain @@ -545,6 +544,54 @@ def __ne__(self, other): return not (self == other) +class ToPickle: + """Mark an object that should be pickled + + Both the scheduler and workers with automatically unpickle this + object on arrival. + + Notice, this requires that the scheduler is allowed to use pickle. + If the configuration option "distributed.scheduler.pickle" is set + to False, the scheduler will raise an exception instead. + """ + + def __init__(self, data): + self.data = data + + def __repr__(self): + return "<ToPickle: %s>" % str(self.data) + + def __eq__(self, other): + return isinstance(other, type(self)) and other.data == self.data + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash(self.data) + + +class Pickled: + """An object that is already pickled into header and frames + + Normal pickled objects are unpickled by the scheduler. 
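Wrapping already-serialized data in ``Pickled`` makes ``dumps()`` reuse the existing
header and frames as-is instead of pickling again, mirroring how ``Serialized`` is
passed through on the msgpack path.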
+ """ + + def __init__(self, header, frames): + self.header = header + self.frames = frames + + def __eq__(self, other): + return ( + isinstance(other, type(self)) + and other.header == self.header + and other.frames == self.frames + ) + + def __ne__(self, other): + return not (self == other) + + def nested_deserialize(x): """ Replace all Serialize and Serialized values nested in *x* diff --git a/distributed/protocol/tests/test_to_pickle.py b/distributed/protocol/tests/test_to_pickle.py new file mode 100644 index 00000000000..7db7a5d9738 --- /dev/null +++ b/distributed/protocol/tests/test_to_pickle.py @@ -0,0 +1,35 @@ +from typing import Dict + +import dask.config +from dask.highlevelgraph import HighLevelGraph, MaterializedLayer + +from distributed.client import Client +from distributed.protocol.serialize import ToPickle +from distributed.utils_test import gen_cluster + + +class NonMsgPackSerializableLayer(MaterializedLayer): + """Layer that uses non-msgpack-serializable data""" + + def __dask_distributed_pack__(self, *args, **kwargs): + ret = super().__dask_distributed_pack__(*args, **kwargs) + # Some info that contains a `list`, which msgpack will convert to + # a tuple if getting the chance. + ret["myinfo"] = ["myinfo"] + return ToPickle(ret) + + @classmethod + def __dask_distributed_unpack__(cls, state, *args, **kwargs): + assert state["myinfo"] == ["myinfo"] + return super().__dask_distributed_unpack__(state, *args, **kwargs) + + +@gen_cluster(client=True) +async def test_non_msgpack_serializable_layer(c: Client, s, w1, w2): + with dask.config.set({"distributed.scheduler.allowed-imports": "test_to_pickle"}): + a = NonMsgPackSerializableLayer({"x": 42}) + layers = {"a": a} + dependencies: Dict[str, set] = {"a": set()} + hg = HighLevelGraph(layers, dependencies) + res = await c.get(hg, "x", sync=False) + assert res == 42 diff --git a/distributed/scheduler.py b/distributed/scheduler.py index b33ff206b1e..85fe15b6542 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -56,7 +56,6 @@ from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker from distributed.batched import BatchedSend from distributed.comm import ( - Comm, get_address_host, normalize_address, resolve_address, @@ -4061,9 +4060,7 @@ def identity(self): } return d - def _to_dict( - self, comm: "Comm | None" = None, *, exclude: "Container[str]" = () - ) -> dict: + def _to_dict(self, *, exclude: "Container[str]" = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. diff --git a/distributed/spill.py b/distributed/spill.py index 7cba8161a7d..86a43af5e00 100644 --- a/distributed/spill.py +++ b/distributed/spill.py @@ -2,19 +2,22 @@ import logging import time -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping, Sized from contextlib import contextmanager from functools import partial -from typing import Any, Literal, NamedTuple, cast +from typing import Any, Literal, NamedTuple, Protocol, cast -import zict from packaging.version import parse as parse_version +import zict + from distributed.protocol import deserialize_bytes, serialize_bytelist from distributed.sizeof import safe_sizeof logger = logging.getLogger(__name__) -has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0") +has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0") +# At the moment of writing, zict 2.2.0 has not been released yet. Support git tip. 
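# Note: under packaging.version, dev pre-releases sort before the final release
# (parse_version("2.2.0.dev2") < parse_version("2.2.0")), so the check below is
# satisfied both by the git tip referred to above and by the eventual 2.2.0 release.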
+has_zict_220 = parse_version(zict.__version__) >= parse_version("2.2.0.dev2") class SpilledSize(NamedTuple): @@ -32,13 +35,43 @@ def __sub__(self, other: SpilledSize) -> SpilledSize: # type: ignore return SpilledSize(self.memory - other.memory, self.disk - other.disk) +class ManualEvictProto(Protocol): + """Duck-type API that a third-party alternative to SpillBuffer must respect (in + addition to MutableMapping) if it wishes to support spilling when the + ``distributed.worker.memory.spill`` threshold is surpassed. + + This is public API. At the moment of writing, Dask-CUDA implements this protocol in + the ProxifyHostFile class. + """ + + @property + def fast(self) -> Sized | bool: + """Access to fast memory. This is normally a MutableMapping, but for the purpose + of the manual eviction API it is just tested for emptiness to know if there is + anything to evict. + """ + ... # pragma: nocover + + def evict(self) -> int: + """Manually evict a key/value pair from fast to slow memory. + Return size of the evicted value in fast memory. + + If the eviction failed for whatever reason, return -1. This method must + guarantee that the key/value pair that caused the issue has been retained in + fast memory and that the problem has been logged internally. + + This method never raises. + """ + ... # pragma: nocover + + # zict.Buffer[str, Any] requires zict >= 2.2.0 class SpillBuffer(zict.Buffer): """MutableMapping that automatically spills out dask key/value pairs to disk when the total size of the stored data exceeds the target. If max_spill is provided the key/value pairs won't be spilled once this threshold has been reached. - Paramaters + Parameters ---------- spill_directory: str Location on disk to write the spill files to @@ -63,14 +96,15 @@ def __init__( ): if max_spill is not False and not has_zict_210: - raise ValueError("zict > 2.0.0 required to set max_weight") + raise ValueError("zict >= 2.1.0 required to set max-spill") - super().__init__( - fast={}, - slow=Slow(spill_directory, max_spill), - n=target, - weight=_in_memory_weight, - ) + slow: MutableMapping[str, Any] = Slow(spill_directory, max_spill) + if has_zict_220: + # If a value is still in use somewhere on the worker since the last time it + # was unspilled, don't duplicate it + slow = zict.Cache(slow, zict.WeakValueMapping()) + + super().__init__(fast={}, slow=slow, n=target, weight=_in_memory_weight) self.last_logged = 0 self.min_log_interval = min_log_interval self.logged_pickle_errors = set() # keys logged with pickle error @@ -163,11 +197,14 @@ def __setitem__(self, key: str, value: Any) -> None: assert key not in self.slow def evict(self) -> int: - """Manually evict the oldest key/value pair, even if target has not been reached. - Returns sizeof(value). + """Implementation of :meth:`ManualEvictProto.evict`. + + Manually evict the oldest key/value pair, even if target has not been + reached. Returns sizeof(value). If the eviction failed (value failed to pickle, disk full, or max_spill exceeded), return -1; the key/value pair that caused the issue will remain in - fast. This method never raises. + fast. The exception has been logged internally. + This method never raises. """ try: with self.handle_errors(None): @@ -204,7 +241,8 @@ def spilled_total(self) -> SpilledSize: The two may differ substantially, e.g. if sizeof() is inaccurate or in case of compression. 
""" - return cast(Slow, self.slow).total_weight + slow = cast(zict.Cache, self.slow).data if has_zict_220 else self.slow + return cast(Slow, slow).total_weight def _in_memory_weight(key: str, value: Any) -> int: @@ -224,6 +262,7 @@ class HandledError(Exception): pass +# zict.Func[str, Any] requires zict >= 2.2.0 class Slow(zict.Func): max_weight: int | Literal[False] weight_by_key: dict[str, SpilledSize] diff --git a/distributed/system.py b/distributed/system.py index 2b032a34024..ad981e8b1cf 100644 --- a/distributed/system.py +++ b/distributed/system.py @@ -5,7 +5,7 @@ __all__ = ("memory_limit", "MEMORY_LIMIT") -def memory_limit(): +def memory_limit() -> int: """Get the memory limit (in bytes) for this system. Takes the minimum value from the following locations: diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 8a909701e2c..8fcc3f31ced 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -5,6 +5,7 @@ import random from contextlib import contextmanager from time import sleep +from typing import Literal import pytest @@ -43,7 +44,13 @@ def assert_amm_log(expect: list[str]): class DemoPolicy(ActiveMemoryManagerPolicy): """Drop or replicate a key n times""" - def __init__(self, action, key, n, candidates): + def __init__( + self, + action: Literal["drop", "replicate"], + key: str, + n: int, + candidates: list[int] | None, + ): self.action = action self.key = key self.n = n @@ -63,7 +70,14 @@ def run(self): yield self.action, ts, candidates -def demo_config(action, key="x", n=10, candidates=None, start=False, interval=0.1): +def demo_config( + action: Literal["drop", "replicate"], + key: str = "x", + n: int = 10, + candidates: list[int] | None = None, + start: bool = False, + interval: float = 0.1, +): """Create a dask config for AMM with DemoPolicy""" return { "distributed.scheduler.active-memory-manager.start": start, @@ -77,6 +91,8 @@ def demo_config(action, key="x", n=10, candidates=None, start=False, interval=0. 
"candidates": candidates, }, ], + # If pause is required, do it manually by setting Worker.status = Status.paused + "distributed.worker.memory.pause": False, } @@ -351,7 +367,7 @@ async def test_drop_from_worker_with_least_free_memory(c, s, *nannies): @gen_cluster( nthreads=[("", 1)] * 8, client=True, - config=demo_config("drop", n=1, candidates={5, 6}), + config=demo_config("drop", n=1, candidates=[5, 6]), ) async def test_drop_with_candidates(c, s, *workers): futures = await c.scatter({"x": 1}, broadcast=True) @@ -363,7 +379,7 @@ async def test_drop_with_candidates(c, s, *workers): await asyncio.sleep(0.01) -@gen_cluster(client=True, config=demo_config("drop", candidates=set())) +@gen_cluster(client=True, config=demo_config("drop", candidates=[])) async def test_drop_with_empty_candidates(c, s, a, b): """Key is not dropped as the plugin proposes an empty set of candidates, not to be confused with None @@ -375,7 +391,9 @@ async def test_drop_with_empty_candidates(c, s, a, b): @gen_cluster( - client=True, nthreads=[("", 1)] * 3, config=demo_config("drop", candidates={2}) + client=True, + nthreads=[("", 1)] * 3, + config=demo_config("drop", candidates=[2]), ) async def test_drop_from_candidates_without_key(c, s, *workers): """Key is not dropped as none of the candidates hold a replica""" @@ -390,7 +408,7 @@ async def test_drop_from_candidates_without_key(c, s, *workers): assert s.tasks["x"].who_has == {ws0, ws1} -@gen_cluster(client=True, config=demo_config("drop", candidates={0})) +@gen_cluster(client=True, config=demo_config("drop", candidates=[0])) async def test_drop_with_bad_candidates(c, s, a, b): """Key is not dropped as all candidates hold waiter tasks""" ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! @@ -404,18 +422,13 @@ async def test_drop_with_bad_candidates(c, s, a, b): assert s.tasks["x"].who_has == {ws0, ws1} -@gen_cluster( - client=True, - nthreads=[("", 1)] * 10, - config=demo_config("drop", n=1), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 10, config=demo_config("drop", n=1)) async def test_drop_prefers_paused_workers(c, s, *workers): x = await c.scatter({"x": 1}, broadcast=True) ts = s.tasks["x"] assert len(ts.who_has) == 10 ws = s.workers[workers[3].address] - workers[3].memory_pause_fraction = 1e-15 + workers[3].status = Status.paused while ws.status != Status.paused: await asyncio.sleep(0.01) @@ -426,11 +439,7 @@ async def test_drop_prefers_paused_workers(c, s, *workers): @pytest.mark.slow -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -445,7 +454,7 @@ async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): while "y" not in a.tasks or a.tasks["y"].state != "executing": await asyncio.sleep(0.01) - a.memory_pause_fraction = 1e-15 + a.status = Status.paused while s.workers[a.address].status != Status.paused: await asyncio.sleep(0.01) assert a.tasks["y"].state == "executing" @@ -455,11 +464,7 @@ async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): assert len(s.tasks["x"].who_has) == 2 -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": 
"20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -470,7 +475,7 @@ async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): b is running and has no dependent tasks """ x = (await c.scatter({"x": 1}, broadcast=True))["x"] - a.memory_pause_fraction = 1e-15 + a.status = Status.paused while s.workers[a.address].status != Status.paused: await asyncio.sleep(0.01) @@ -481,11 +486,7 @@ async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): @pytest.mark.slow @pytest.mark.parametrize("pause", [True, False]) -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -505,8 +506,8 @@ async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause await asyncio.sleep(0.01) if pause: - a.memory_pause_fraction = 1e-15 - b.memory_pause_fraction = 1e-15 + a.status = Status.paused + b.status = Status.paused while any(ws.status != Status.paused for ws in s.workers.values()): await asyncio.sleep(0.01) @@ -519,12 +520,7 @@ async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause @pytest.mark.slow -@gen_cluster( - client=True, - nthreads=[("", 1)] * 3, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_5(c, s, w1, w2, w3): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -549,7 +545,7 @@ def executing() -> bool: while not executing(): await asyncio.sleep(0.01) - w1.memory_pause_fraction = 1e-15 + w1.status = Status.paused while s.workers[w1.address].status != Status.paused: await asyncio.sleep(0.01) assert executing() @@ -635,7 +631,7 @@ async def test_replicate_to_worker_with_most_free_memory(c, s, *nannies): @gen_cluster( nthreads=[("", 1)] * 8, client=True, - config=demo_config("replicate", n=1, candidates={5, 6}), + config=demo_config("replicate", n=1, candidates=[5, 6]), ) async def test_replicate_with_candidates(c, s, *workers): wss = list(s.workers.values()) @@ -647,7 +643,7 @@ async def test_replicate_with_candidates(c, s, *workers): await asyncio.sleep(0.01) -@gen_cluster(client=True, config=demo_config("replicate", candidates=set())) +@gen_cluster(client=True, config=demo_config("replicate", candidates=[])) async def test_replicate_with_empty_candidates(c, s, a, b): """Key is not replicated as the plugin proposes an empty set of candidates, not to be confused with None @@ -658,7 +654,7 @@ async def test_replicate_with_empty_candidates(c, s, a, b): assert len(s.tasks["x"].who_has) == 1 -@gen_cluster(client=True, config=demo_config("replicate", candidates={0})) +@gen_cluster(client=True, config=demo_config("replicate", candidates=[0])) async def test_replicate_to_candidates_with_key(c, s, a, b): """Key is not replicated as all 
candidates already hold replicas""" ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! @@ -668,14 +664,9 @@ async def test_replicate_to_candidates_with_key(c, s, a, b): assert s.tasks["x"].who_has == {ws0} -@gen_cluster( - client=True, - nthreads=[("", 1)] * 3, - config=demo_config("replicate"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("replicate")) async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2): - w1.memory_pause_fraction = 1e-15 + w1.status = Status.paused while s.workers[w1.address].status != Status.paused: await asyncio.sleep(0.01) @@ -687,13 +678,9 @@ async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2): assert "x" not in w1.data -@gen_cluster( - client=True, - config=demo_config("replicate"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("replicate")) async def test_replicate_avoids_paused_workers_2(c, s, a, b): - b.memory_pause_fraction = 1e-15 + b.status = Status.paused while s.workers[b.address].status != Status.paused: await asyncio.sleep(0.01) @@ -892,13 +879,14 @@ async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4): "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 999, "distributed.scheduler.active-memory-manager.policies": [], + "distributed.worker.memory.pause": False, }, ) async def test_RetireWorker_all_recipients_are_paused(c, s, a, b): ws_a = s.workers[a.address] ws_b = s.workers[b.address] - b.memory_pause_fraction = 1e-15 + b.status = Status.paused while ws_b.status != Status.paused: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 31062a40039..ef53e9e1ccb 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1,10 +1,7 @@ import asyncio from unittest import mock -import pytest - import distributed -from distributed import Worker from distributed.core import CommClosedError from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc @@ -208,58 +205,4 @@ async def wait_and_raise(*args, **kwargs): await asyncio.sleep(0.01) # Everything should still be executing as usual after this - await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) - - -class LargeButForbiddenSerialization: - def __reduce__(self): - raise RuntimeError("I will never serialize!") - - def __sizeof__(self) -> int: - """Ensure this is immediately tried to spill""" - return 1_000_000_000_000 - - -def test_ensure_spilled_immediately(tmpdir): - """See also test_value_raises_during_spilling""" - import sys - - from distributed.spill import SpillBuffer - - mem_target = 1000 - buf = SpillBuffer(tmpdir, target=mem_target) - buf["key"] = 1 - - obj = LargeButForbiddenSerialization() - assert sys.getsizeof(obj) > mem_target - with pytest.raises( - TypeError, - match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}", - ): - buf["error"] = obj - - -@gen_cluster(client=True, nthreads=[]) -async def test_value_raises_during_spilling(c, s): - """See also test_ensure_spilled_immediately""" - - # Use a worker with a default memory limit - async with Worker( - s.address, - ) as w: - - def produce_evil_data(): - return LargeButForbiddenSerialization() - - fut = c.submit(produce_evil_data) - - await wait_for_state(fut.key, "error", w) - - with pytest.raises( - 
TypeError, - match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}", - ): - await fut - - # Everything should still be executing as usual after this - await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) + assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 383558ca54c..f9558d2dd3f 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5878,14 +5878,14 @@ def bad_fn(x): @gen_cluster( client=True, nthreads=[("", 1)] * 10, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_scatter_and_replicate_avoid_paused_workers( c, s, *workers, workers_arg, direct, broadcast ): paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)] for w in paused_workers: - w.memory_pause_fraction = 1e-15 + w.status = Status.paused while any(s.workers[w.address].status != Status.paused for w in paused_workers): await asyncio.sleep(0.01) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index afc2dba7cd3..934e52c4fe7 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -6,7 +6,6 @@ import random import sys from contextlib import suppress -from time import sleep from unittest import mock import psutil @@ -27,7 +26,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import TimeoutError, parse_ports -from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc +from distributed.utils_test import captured_logger, gen_cluster, gen_test pytestmark = pytest.mark.ci1 @@ -265,55 +264,25 @@ async def test_nanny_timeout(c, s, a): @gen_cluster( - nthreads=[("127.0.0.1", 1)], - client=True, - Worker=Nanny, - worker_kwargs={"memory_limit": "400 MiB"}, -) -async def test_nanny_terminate(c, s, a): - def leak(): - L = [] - while True: - L.append(b"0" * 5_000_000) - sleep(0.01) - - before = a.process.pid - with captured_logger(logging.getLogger("distributed.nanny")) as logger: - future = c.submit(leak) - while a.process.pid == before: - await asyncio.sleep(0.01) - - out = logger.getvalue() - assert "restart" in out.lower() - assert "memory" in out.lower() - - -@gen_cluster( - nthreads=[("127.0.0.1", 1)] * 8, + nthreads=[("", 1)] * 8, client=True, - Worker=Worker, clean_kwargs={"threads": False}, + config={"distributed.worker.memory.pause": False}, ) -async def test_throttle_outgoing_connections(c, s, a, *workers): - # But a bunch of small data on worker a - await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) +async def test_throttle_outgoing_connections(c, s, a, *other_workers): + # Put a bunch of small data on worker a + logging.getLogger("distributed.worker").setLevel(logging.DEBUG) remote_data = c.map( lambda x: b"0" * 10000, range(10), pure=False, workers=[a.address] ) await wait(remote_data) - def pause(dask_worker): - # Patch paused and memory_monitor on the one worker - # This is is very fragile, since a refactor of memory_monitor to - # remove _memory_monitoring will break this test. 
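The recurring pattern in the hunks above: rather than forcing a pause through memory_pause_fraction = 1e-15, the rewritten tests disable the automatic pause threshold in config and flip Worker.status by hand. A minimal sketch of that pattern (the function name, worker counts and sleep intervals are illustrative, not taken from this changeset):

import asyncio

from distributed.core import Status
from distributed.utils_test import gen_cluster, inc


@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 2,
    # Disable the memory-monitor-driven pause so it cannot fight the manual toggle
    config={"distributed.worker.memory.pause": False},
)
async def manual_pause_sketch(c, s, a, b):
    # Pause worker `a` directly instead of tweaking memory_pause_fraction
    a.status = Status.paused
    # Wait until the scheduler has observed the status change
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # New work should now land on the non-paused worker only
    futures = c.map(inc, range(4))
    await c.gather(futures)
    assert not a.data and len(b.data) == 4

    a.status = Status.running  # un-pause when done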
- dask_worker._memory_monitoring = True - dask_worker.status = Status.paused - dask_worker.outgoing_current_count = 2 + a.status = Status.paused + a.outgoing_current_count = 2 - await c.run(pause, workers=[a.address]) requests = [ await a.get_data(await w.rpc.connect(w.address), keys=[f.key], who=w.address) - for w in workers + for w in other_workers for f in remote_data ] await wait(requests) @@ -322,36 +291,13 @@ def pause(dask_worker): assert "throttling" in wlogs.lower() -@gen_cluster(nthreads=[], client=True) -async def test_avoid_memory_monitor_if_zero_limit(c, s): - nanny = await Nanny(s.address, loop=s.loop, memory_limit=0) - typ = await c.run(lambda dask_worker: type(dask_worker.data)) - assert typ == {nanny.worker_address: dict} - pcs = await c.run(lambda dask_worker: list(dask_worker.periodic_callbacks)) - assert "memory" not in pcs - assert "memory" not in nanny.periodic_callbacks - - future = c.submit(inc, 1) - assert await future == 2 - await asyncio.sleep(0.02) - - await c.submit(inc, 2) # worker doesn't pause - - await nanny.close() - - -@gen_cluster(nthreads=[], client=True) -async def test_scheduler_address_config(c, s): +@gen_cluster(nthreads=[]) +async def test_scheduler_address_config(s): with dask.config.set({"scheduler-address": s.address}): - nanny = await Nanny(loop=s.loop) - assert nanny.scheduler.address == s.address - - start = time() - while not s.workers: - await asyncio.sleep(0.1) - assert time() < start + 10 - - await nanny.close() + async with Nanny() as nanny: + assert nanny.scheduler.address == s.address + while not s.workers: + await asyncio.sleep(0.01) @pytest.mark.slow @@ -421,14 +367,6 @@ async def test_environment_variable_config(c, s, monkeypatch): assert results[n.worker_address]["D"] == "123" -@gen_cluster(nthreads=[], client=True) -async def test_data_types(c, s): - w = await Nanny(s.address, data=dict) - r = await c.run(lambda dask_worker: type(dask_worker.data)) - assert r[w.worker_address] == dict - await w.close() - - @gen_cluster(nthreads=[]) async def test_local_directory(s): with tmpfile() as fn: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 2f8f8cb3a92..2750868dc01 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -20,7 +20,15 @@ from dask import delayed from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename -from distributed import Client, Nanny, Worker, fire_and_forget, wait +from distributed import ( + Client, + Lock, + Nanny, + SchedulerPlugin, + Worker, + fire_and_forget, + wait, +) from distributed.compatibility import LINUX, WINDOWS from distributed.core import ConnectionPool, Status, clean_exception, connect, rpc from distributed.metrics import time @@ -2179,7 +2187,6 @@ async def test_gather_allow_worker_reconnect( """ # GH3246 if reschedule_different_worker: - from distributed.diagnostics.plugin import SchedulerPlugin class SwitchRestrictions(SchedulerPlugin): def __init__(self, scheduler): @@ -2192,8 +2199,6 @@ def transition(self, key, start, finish, **kwargs): plugin = SwitchRestrictions(s) s.add_plugin(plugin) - from distributed import Lock - b_address = b.address def inc_slow(x, lock): @@ -2215,8 +2220,9 @@ def reducer(*args): def finalizer(addr): if swap_data_insert_order: w = get_worker() - new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]} - w.data = new_data + new_data = dict(reversed(list(w.data.items()))) + w.data.clear() + w.data.update(new_data) return addr z = c.submit(reducer, x, key="reducer", 
workers=[a.address]) @@ -3289,10 +3295,10 @@ async def test_set_restrictions(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1)] * 3, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_avoid_paused_workers(c, s, w1, w2, w3): - w2.memory_pause_fraction = 1e-15 + w2.status = Status.paused while s.workers[w2.address].status != Status.paused: await asyncio.sleep(0.01) futures = c.map(slowinc, range(8), delay=0.1) @@ -3303,25 +3309,6 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3): assert len(w1.data) + len(w3.data) == 8 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs={"memory_monitor_interval": "20ms"}, -) -async def test_unpause_schedules_unrannable_tasks(c, s, a): - a.memory_pause_fraction = 1e-15 - while s.workers[a.address].status != Status.paused: - await asyncio.sleep(0.01) - - fut = c.submit(inc, 1, key="x") - while not s.unrunnable: - await asyncio.sleep(0.001) - assert next(iter(s.unrunnable)).key == "x" - - a.memory_pause_fraction = 0.8 - assert await fut == 2 - - @gen_cluster(client=True, nthreads=[("", 1)]) async def test_Scheduler__to_dict(c, s, a): futs = c.map(inc, range(2)) @@ -3489,9 +3476,6 @@ async def test_dump_cluster_state(s: Scheduler, *workers: Worker, format): @gen_cluster(nthreads=[]) async def test_idempotent_plugins(s): - - from distributed.diagnostics.plugin import SchedulerPlugin - class IdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "idempotentplugin" @@ -3515,9 +3499,6 @@ def start(self, scheduler): @gen_cluster(nthreads=[]) async def test_non_idempotent_plugins(s): - - from distributed.diagnostics.plugin import SchedulerPlugin - class NonIdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "nonidempotentplugin" diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index c30aa6cefc6..01f2369159d 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -1,20 +1,27 @@ from __future__ import annotations +import gc import logging import os import pytest -zict = pytest.importorskip("zict") -from packaging.version import parse as parse_version - from dask.sizeof import sizeof from distributed.compatibility import WINDOWS from distributed.protocol import serialize_bytelist -from distributed.spill import SpillBuffer +from distributed.spill import SpillBuffer, has_zict_210, has_zict_220 from distributed.utils_test import captured_logger +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) +requires_zict_220 = pytest.mark.skipif( + not has_zict_220, + reason="requires zict version >= 2.2.0", +) + def psize(*objs) -> tuple[int, int]: return ( @@ -23,15 +30,30 @@ def psize(*objs) -> tuple[int, int]: ) +def assert_buf(buf: SpillBuffer, expect_fast: dict, expect_slow: dict) -> None: + # assertions on fast + assert dict(buf.fast) == expect_fast + assert buf.fast.weights == {k: sizeof(v) for k, v in expect_fast.items()} + assert buf.fast.total_weight == sum(sizeof(v) for v in expect_fast.values()) + for k, v in buf.fast.items(): + assert buf[k] is v + + # assertions on slow + assert set(buf.slow) == expect_slow.keys() + slow = buf.slow.data if has_zict_220 else buf.slow # type: ignore + assert slow.weight_by_key == {k: psize(v) for k, v in expect_slow.items()} + total_weight = psize(*expect_slow.values()) + assert slow.total_weight == total_weight + assert buf.spilled_total == total_weight + + def 
test_spillbuffer(tmpdir): buf = SpillBuffer(str(tmpdir), target=300) # Convenience aliases assert buf.memory is buf.fast assert buf.disk is buf.slow - assert not buf.slow.weight_by_key - assert buf.slow.total_weight == (0, 0) - assert buf.spilled_total == (0, 0) + assert_buf(buf, {}, {}) a, b, c, d = "a" * 100, "b" * 99, "c" * 98, "d" * 97 @@ -40,75 +62,48 @@ def test_spillbuffer(tmpdir): assert psize(a)[0] != psize(a)[1] buf["a"] = a - assert not buf.slow - assert buf.fast.weights == {"a": sizeof(a)} - assert buf.fast.total_weight == sizeof(a) - assert buf.slow.weight_by_key == {} - assert buf.slow.total_weight == (0, 0) + assert_buf(buf, {"a": a}, {}) assert buf["a"] == a buf["b"] = b - assert not buf.slow - assert not buf.slow.weight_by_key - assert buf.slow.total_weight == (0, 0) + assert_buf(buf, {"a": a, "b": b}, {}) buf["c"] = c - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.slow.total_weight == psize(a) + assert_buf(buf, {"b": b, "c": c}, {"a": a}) assert buf["a"] == a - assert set(buf.slow) == {"b"} - assert buf.slow.weight_by_key == {"b": psize(b)} - assert buf.slow.total_weight == psize(b) + assert_buf(buf, {"a": a, "c": c}, {"b": b}) buf["d"] = d - assert set(buf.slow) == {"b", "c"} - assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)} - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a, "d": d}, {"b": b, "c": c}) # Deleting an in-memory key does not automatically move spilled keys back to memory del buf["a"] - assert set(buf.slow) == {"b", "c"} - assert buf.slow.weight_by_key == {"b": psize(b), "c": psize(c)} - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"d": d}, {"b": b, "c": c}) with pytest.raises(KeyError): buf["a"] # Deleting a spilled key updates the metadata del buf["b"] - assert set(buf.slow) == {"c"} - assert buf.slow.weight_by_key == {"c": psize(c)} - assert buf.slow.total_weight == psize(c) + assert_buf(buf, {"d": d}, {"c": c}) with pytest.raises(KeyError): buf["b"] # Updating a spilled key moves it to the top of the LRU and to memory - buf["c"] = c * 2 - assert set(buf.slow) == {"d"} - assert buf.slow.weight_by_key == {"d": psize(d)} - assert buf.slow.total_weight == psize(d) + c2 = c * 2 + buf["c"] = c2 + assert_buf(buf, {"c": c2}, {"d": d}) # Single key is larger than target and goes directly into slow e = "e" * 500 buf["e"] = e - assert set(buf.slow) == {"d", "e"} - assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)} - assert buf.slow.total_weight == psize(d, e) + assert_buf(buf, {"c": c2}, {"d": d, "e": e}) # Updating a spilled key with another larger than target updates slow directly d = "d" * 500 buf["d"] = d - assert set(buf.slow) == {"d", "e"} - assert buf.slow.weight_by_key == {"d": psize(d), "e": psize(e)} - assert buf.slow.total_weight == psize(d, e) - - -requires_zict_210 = pytest.mark.skipif( - parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) + assert_buf(buf, {"c": c2}, {"d": d, "e": e}) @requires_zict_210 @@ -120,32 +115,17 @@ def test_spillbuffer_maxlim(tmpdir): # size of a is bigger than target and is smaller than max_spill; # key should be in slow buf["a"] = a - assert not buf.fast - assert not buf.fast.weights - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.slow.total_weight == psize(a) + assert_buf(buf, {}, {"a": a}) assert buf["a"] == a # size of b is smaller than target key should be in fast buf["b"] = b - assert set(buf.fast) == 
{"b"} - assert buf.fast.weights == {"b": sizeof(b)} - assert buf["b"] == b - assert buf.fast.total_weight == sizeof(b) + assert_buf(buf, {"b": b}, {"a": a}) # size of c is smaller than target but b+c > target, c should stay in fast and b # move to slow since the max_spill limit has not been reached yet - buf["c"] = c - assert set(buf.fast) == {"c"} - assert buf.fast.weights == {"c": sizeof(c)} - assert buf["c"] == c - assert buf.fast.total_weight == sizeof(c) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c}, {"a": a, "b": b}) # size of e < target but e+c > target, this will trigger movement of c to slow # but the max spill limit prevents it. Resulting in e remaining in fast @@ -154,15 +134,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["e"] = e assert "disk reached capacity" in logs_e.getvalue() - - assert set(buf.fast) == {"c", "e"} - assert buf.fast.weights == {"c": sizeof(c), "e": sizeof(e)} - assert buf["e"] == e - assert buf.fast.total_weight == sizeof(c) + sizeof(e) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c, "e": e}, {"a": a, "b": b}) # size of d > target, d should go to slow but slow reached the max_spill limit then # d will end up on fast with c (which can't be move to slow because it won't fit @@ -171,15 +143,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["d"] = d assert "disk reached capacity" in logs_d.getvalue() - - assert set(buf.fast) == {"c", "d", "e"} - assert buf.fast.weights == {"c": sizeof(c), "d": sizeof(d), "e": sizeof(e)} - assert buf["d"] == d - assert buf.fast.total_weight == sizeof(c) + sizeof(d) + sizeof(e) - - assert set(buf.slow) == {"a", "b"} - assert buf.slow.weight_by_key == {"a": psize(a), "b": psize(b)} - assert buf.slow.total_weight == psize(a, b) + assert_buf(buf, {"c": c, "d": d, "e": e}, {"a": a, "b": b}) # Overwrite a key that was in slow, but the size of the new key is larger than # max_spill @@ -191,11 +155,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["a"] = a_large assert "disk reached capacity" in logs_alarge.getvalue() - - assert set(buf.fast) == {"a", "d", "e"} - assert set(buf.slow) == {"b", "c"} - assert buf.fast.total_weight == sizeof(d) + sizeof(a_large) + sizeof(e) - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a_large, "d": d, "e": e}, {"b": b, "c": c}) # Overwrite a key that was in fast, but the size of the new key is larger than # max_spill @@ -205,11 +165,7 @@ def test_spillbuffer_maxlim(tmpdir): buf["d"] = d_large assert "disk reached capacity" in logs_dlarge.getvalue() - - assert set(buf.fast) == {"a", "d", "e"} - assert set(buf.slow) == {"b", "c"} - assert buf.fast.total_weight == sizeof(a_large) + sizeof(d_large) + sizeof(e) - assert buf.slow.total_weight == psize(b, c) + assert_buf(buf, {"a": a_large, "d": d_large, "e": e}, {"b": b, "c": c}) class MyError(Exception): @@ -241,13 +197,12 @@ def test_spillbuffer_fail_to_serialize(tmpdir): # spill.py must remain silent because we're already logging in worker.py assert not logs_bad_key.getvalue() - assert not set(buf.fast) - assert not set(buf.slow) + assert_buf(buf, {}, {}) b = Bad(size=100) # this is small enough to fit in memory/fast buf["b"] = b - assert set(buf.fast) == {"b"} + assert_buf(buf, {"b": b}, {}) c = "c" * 100 with captured_logger(logging.getLogger("distributed.spill")) as logs_bad_key_mem: 
@@ -259,9 +214,7 @@ def test_spillbuffer_fail_to_serialize(tmpdir): logs_value = logs_bad_key_mem.getvalue() assert "Failed to pickle" in logs_value # from distributed.spill assert "Traceback" in logs_value # from distributed.spill - assert set(buf.fast) == {"b", "c"} - assert buf.fast.total_weight == sizeof(b) + sizeof(c) - assert not set(buf.slow) + assert_buf(buf, {"b": b, "c": c}, {}) @requires_zict_210 @@ -279,8 +232,7 @@ def test_spillbuffer_oserror(tmpdir): # let's have something in fast and something in slow buf["a"] = a buf["b"] = b - assert set(buf.fast) == {"b"} - assert set(buf.slow) == {"a"} + assert_buf(buf, {"b": b}, {"a": a}) # modify permissions of disk to be read only. # This causes writes to raise OSError, just like in case of disk full. @@ -291,15 +243,10 @@ def test_spillbuffer_oserror(tmpdir): buf["c"] = c assert "Spill to disk failed" in logs_oserror_slow.getvalue() - assert set(buf.fast) == {"b", "c"} - assert set(buf.slow) == {"a"} - - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.fast.weights == {"b": sizeof(b), "c": sizeof(c)} + assert_buf(buf, {"b": b, "c": c}, {"a": a}) del buf["c"] - assert set(buf.fast) == {"b"} - assert set(buf.slow) == {"a"} + assert_buf(buf, {"b": b}, {"a": a}) # add key to fast which is smaller than target but when added it triggers spill, # which triggers OSError @@ -307,40 +254,26 @@ def test_spillbuffer_oserror(tmpdir): buf["d"] = d assert "Spill to disk failed" in logs_oserror_evict.getvalue() - assert set(buf.fast) == {"b", "d"} - assert set(buf.slow) == {"a"} - - assert buf.slow.weight_by_key == {"a": psize(a)} - assert buf.fast.weights == {"b": sizeof(b), "d": sizeof(d)} + assert_buf(buf, {"b": b, "d": d}, {"a": a}) @requires_zict_210 def test_spillbuffer_evict(tmpdir): buf = SpillBuffer(str(tmpdir), target=300, min_log_interval=0) - a_bad = Bad(size=100) + bad = Bad(size=100) a = "a" * 100 buf["a"] = a - - assert set(buf.fast) == {"a"} - assert not set(buf.slow) - assert buf.fast.weights == {"a": sizeof(a)} + assert_buf(buf, {"a": a}, {}) # successful eviction weight = buf.evict() assert weight == sizeof(a) + assert_buf(buf, {}, {"a": a}) - assert not buf.fast - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} - - buf["a_bad"] = a_bad - - assert set(buf.fast) == {"a_bad"} - assert buf.fast.weights == {"a_bad": sizeof(a_bad)} - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} + buf["bad"] = bad + assert_buf(buf, {"bad": bad}, {"a": a}) # unsuccessful eviction with captured_logger(logging.getLogger("distributed.spill")) as logs_evict_key: @@ -349,7 +282,63 @@ def test_spillbuffer_evict(tmpdir): assert "Failed to pickle" in logs_evict_key.getvalue() # bad keys stays in fast - assert set(buf.fast) == {"a_bad"} - assert buf.fast.weights == {"a_bad": sizeof(a_bad)} - assert set(buf.slow) == {"a"} - assert buf.slow.weight_by_key == {"a": psize(a)} + assert_buf(buf, {"bad": bad}, {"a": a}) + + +class SupportsWeakRef: + def __init__(self, n): + self.n = n + + def __sizeof__(self): + return self.n + + +class NoWeakRef: + __slots__ = ("n",) + + def __init__(self, n): + self.n = n + + def __sizeof__(self): + return self.n + + +@pytest.mark.parametrize( + "cls,expect_cached", + [ + (SupportsWeakRef, has_zict_220), + (NoWeakRef, False), + ], +) +@pytest.mark.parametrize("size", [60, 110]) +def test_weakref_cache(tmpdir, cls, expect_cached, size): + buf = SpillBuffer(str(tmpdir), target=100) + + # Run this test twice: + # - x is smaller than target and is evicted by 
y; + # - x is individually larger than target and it never touches fast + x = cls(size) + buf["x"] = x + if size < 100: + buf["y"] = cls(60) # spill x + assert "x" in buf.slow + + # Test that we update the weakref cache on setitem + assert (buf["x"] is x) == expect_cached + + id_x = id(x) + del x + gc.collect() # Only needed on pypy + + if size < 100: + buf["y"] + assert "x" in buf.slow + + x2 = buf["x"] + assert id(x2) != id_x + if size < 100: + buf["y"] + assert "x" in buf.slow + + # Test that we update the weakref cache on getitem + assert (buf["x"] is x2) == expect_cached diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 2a6a220bb97..e6469eae8b4 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -834,10 +834,10 @@ async def test_steal_twice(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1)] * 3, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_paused_workers_must_not_steal(c, s, w1, w2, w3): - w2.memory_pause_fraction = 1e-15 + w2.status = Status.paused while s.workers[w2.address].status != Status.paused: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index db5cc866b5c..dc481aae17f 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -202,9 +202,8 @@ async def test_gen_test_double_parametrized(foo, bar): @gen_test() -async def test_gen_test_pytest_fixture(tmp_path, c): +async def test_gen_test_pytest_fixture(tmp_path): assert isinstance(tmp_path, pathlib.Path) - assert isinstance(c, Client) @contextmanager diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 9fd29884edd..0caa128c02b 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,7 +16,6 @@ import psutil import pytest -from packaging.version import parse as parse_version from tlz import first, pluck, sliding_window import dask @@ -27,7 +26,6 @@ import distributed from distributed import ( Client, - Event, Nanny, Reschedule, default_client, @@ -59,21 +57,10 @@ slowinc, slowsum, ) -from distributed.worker import Worker, error_message, logger, parse_memory_limit +from distributed.worker import Worker, error_message, logger pytestmark = pytest.mark.ci1 -try: - import zict -except ImportError: - zict = None # type: ignore - -requires_zict = pytest.mark.skipif(not zict, reason="requires zict") -requires_zict_210 = pytest.mark.skipif( - not zict or parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) - @gen_cluster(nthreads=[]) async def test_worker_nthreads(s): @@ -906,109 +893,6 @@ def __sizeof__(self): assert result.data == 123 -class FailToPickle: - def __init__(self, *, reported_size=0, actual_size=0): - self.reported_size = int(reported_size) - self.data = "x" * int(actual_size) - - def __getstate__(self): - raise TypeError() - - def __sizeof__(self): - return self.reported_size - - -async def assert_basic_futures(c: Client) -> None: - futures = c.map(inc, range(10)) - results = await c.gather(futures) - assert results == list(map(inc, range(10))) - - -@requires_zict -@gen_cluster(client=True) -async def test_fail_write_to_disk_target_1(c, s, a, b): - """Test failure to spill triggered by key which is individually larger - than target. The data is lost and the task is marked as failed; - the worker remains in usable condition. 
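test_weakref_cache above exercises the zict >= 2.2.0 behaviour where slow is wrapped in a zict.Cache, so un-spilling can hand back the original object as long as it is still alive and weak-referenceable. A condensed illustration of the idea being tested (the class name and sizes are made up for the example):

import tempfile

from distributed.spill import SpillBuffer, has_zict_220


class Payload:
    """Weak-referenceable value whose reported size is inflated via __sizeof__"""

    def __init__(self, n):
        self.n = n

    def __sizeof__(self):
        return self.n


with tempfile.TemporaryDirectory() as tmpdir:
    buf = SpillBuffer(tmpdir, target=100)
    x = Payload(60)
    buf["x"] = x
    buf["y"] = Payload(60)  # pushes "x" out to disk
    assert "x" in buf.slow

    # With zict >= 2.2.0 the weakref cache returns the very same object that was
    # spilled; with older zict a fresh copy is unpickled from disk instead.
    assert (buf["x"] is x) == has_zict_220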
- """ - future = c.submit(FailToPickle, reported_size=100e9) - await wait(future) - - assert future.status == "error" - - with pytest.raises(TypeError, match="Could not serialize"): - await future - - await assert_basic_futures(c) - - -@requires_zict -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit="1 kiB", - memory_target_fraction=0.5, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_fail_write_to_disk_target_2(c, s, a): - """Test failure to spill triggered by key which is individually smaller - than target, so it is not spilled immediately. The data is retained and - the task is NOT marked as failed; the worker remains in usable condition. - """ - x = c.submit(FailToPickle, reported_size=256, key="x") - await wait(x) - assert x.status == "finished" - assert set(a.data.memory) == {"x"} - - y = c.submit(lambda: "y" * 256, key="y") - await wait(y) - if parse_version(zict.__version__) <= parse_version("2.0.0"): - assert set(a.data.memory) == {"y"} - else: - assert set(a.data.memory) == {"x", "y"} - assert not a.data.disk - - await assert_basic_futures(c) - - -@requires_zict_210 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_monitor_interval="10ms", - memory_limit="1 kiB", # Spill everything - memory_target_fraction=False, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ), -) -async def test_fail_write_to_disk_spill(c, s, a): - """Test failure to evict a key, triggered by the spill threshold""" - with captured_logger(logging.getLogger("distributed.spill")) as logs: - bad = c.submit(FailToPickle, actual_size=1_000_000, key="bad") - await wait(bad) - - # Must wait for memory monitor to kick in - while True: - logs_value = logs.getvalue() - if logs_value: - break - await asyncio.sleep(0.01) - - assert "Failed to pickle" in logs_value - assert "Traceback" in logs_value - - # key is in fast - assert bad.status == "finished" - assert bad.key in a.data.fast - - await assert_basic_futures(c) - - @gen_cluster() async def test_pid(s, a, b): assert s.workers[a.address].pid == os.getpid() @@ -1187,245 +1071,6 @@ async def test_statistical_profiling_2(c, s, a, b): break -@requires_zict -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit=1200 / 0.6, - memory_target_fraction=0.6, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_spill_target_threshold(c, s, a): - """Test distributed.worker.memory.target threshold. Note that in this test we - disabled spill and pause thresholds, which work on the process memory, and just left - the target threshold, which works on managed memory so it is unperturbed by the - several hundreds of MB of unmanaged memory that are typical of the test suite. 
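The tests removed here come back later in this diff as distributed/tests/test_worker_memory.py, with one systematic change: thresholds are no longer passed as Worker keyword arguments (memory_target_fraction and friends) but through dask config keys. A sketch of the config-driven style, using the keys that appear throughout the new test module:

import dask

# The same knobs that used to be Worker(..., memory_target_fraction=...) kwargs:
memory_config = {
    "distributed.worker.memory.target": 0.6,   # managed-memory spill threshold
    "distributed.worker.memory.spill": False,  # disable process-memory spilling
    "distributed.worker.memory.pause": False,  # disable automatic pausing
}

with dask.config.set(memory_config):
    # Workers created inside this block -- or tests decorated with
    # @gen_cluster(config=memory_config) -- pick these values up at init time.
    assert dask.config.get("distributed.worker.memory.target") == 0.6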
- """ - x = c.submit(lambda: "x" * 500, key="x") - await wait(x) - y = c.submit(lambda: "y" * 500, key="y") - await wait(y) - - assert set(a.data) == {"x", "y"} - assert set(a.data.memory) == {"x", "y"} - - z = c.submit(lambda: "z" * 500, key="z") - await wait(z) - assert set(a.data) == {"x", "y", "z"} - assert set(a.data.memory) == {"y", "z"} - assert set(a.data.disk) == {"x"} - - await x - assert set(a.data.memory) == {"x", "z"} - assert set(a.data.disk) == {"y"} - - -@requires_zict_210 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit=1600, - max_spill=600, - memory_target_fraction=0.6, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_spill_constrained(c, s, w): - """Test distributed.worker.memory.max-spill parameter""" - # spills starts at 1600*0.6=960 bytes of managed memory - - # size in memory ~200; size on disk ~400 - x = c.submit(lambda: "x" * 200, key="x") - await wait(x) - # size in memory ~500; size on disk ~700 - y = c.submit(lambda: "y" * 500, key="y") - await wait(y) - - assert set(w.data) == {x.key, y.key} - assert set(w.data.memory) == {x.key, y.key} - - z = c.submit(lambda: "z" * 500, key="z") - await wait(z) - - assert set(w.data) == {x.key, y.key, z.key} - - # max_spill has not been reached - assert set(w.data.memory) == {y.key, z.key} - assert set(w.data.disk) == {x.key} - - # zb is individually larger than max_spill - zb = c.submit(lambda: "z" * 1700, key="zb") - await wait(zb) - - assert set(w.data.memory) == {y.key, z.key, zb.key} - assert set(w.data.disk) == {x.key} - - del zb - while "zb" in w.data: - await asyncio.sleep(0.01) - - # zc is individually smaller than max_spill, but the evicted key together with - # x it exceeds max_spill - zc = c.submit(lambda: "z" * 500, key="zc") - await wait(zc) - assert set(w.data.memory) == {y.key, z.key, zc.key} - assert set(w.data.disk) == {x.key} - - -@requires_zict -@gen_cluster( - nthreads=[("", 1)], - client=True, - worker_kwargs=dict( - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=False, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ), -) -async def test_spill_spill_threshold(c, s, a): - """Test distributed.worker.memory.spill threshold. - Test that the spill threshold uses the process memory and not the managed memory - reported by sizeof(), which may be inaccurate. - """ - a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0 - x = c.submit(inc, 0, key="x") - while not a.data.disk: - await asyncio.sleep(0.01) - assert await x == 1 - - -@requires_zict -@pytest.mark.parametrize( - "memory_target_fraction,managed,expect_spilled", - [ - # no target -> no hysteresis - # Over-report managed memory to test that the automated LRU eviction based on - # target is never triggered - (False, int(10e9), 1), - # Under-report managed memory, so that we reach the spill threshold for process - # memory without first reaching the target threshold for managed memory - # target == spill -> no hysteresis - (0.7, 0, 1), - # target < spill -> hysteresis from spill to target - (0.4, 0, 7), - ], -) -@gen_cluster(nthreads=[], client=True) -async def test_spill_hysteresis(c, s, memory_target_fraction, managed, expect_spilled): - """ - 1. Test that you can enable the spill threshold while leaving the target threshold - to False - 2. 
Test the hysteresis system where, once you reach the spill threshold, the worker - won't stop spilling until the target threshold is reached - """ - - class C: - def __sizeof__(self): - return managed - - async with Worker( - s.address, - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=memory_target_fraction, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ) as a: - a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast) - - # Add 500MB (reported) process memory. Spilling must not happen. - futures = [c.submit(C, pure=False) for _ in range(10)] - await wait(futures) - await asyncio.sleep(0.1) - assert not a.data.disk - - # Add another 250MB unmanaged memory. This must trigger the spilling. - futures += [c.submit(C, pure=False) for _ in range(5)] - await wait(futures) - - # Wait until spilling starts. Then, wait until it stops. - prev_n = 0 - while not a.data.disk or len(a.data.disk) > prev_n: - prev_n = len(a.data.disk) - await asyncio.sleep(0) - - assert len(a.data.disk) == expect_spilled - - -@gen_cluster( - nthreads=[("", 1)], - client=True, - worker_kwargs=dict( - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=False, - memory_spill_fraction=False, - memory_pause_fraction=0.8, - ), -) -async def test_pause_executor(c, s, a): - mocked_rss = 0 - a.monitor.get_process_memory = lambda: mocked_rss - - # Task that is running when the worker pauses - ev_x = Event() - - def f(ev): - ev.wait() - return 1 - - x = c.submit(f, ev_x, key="x") - while a.executing_count != 1: - await asyncio.sleep(0.01) - - with captured_logger(logging.getLogger("distributed.worker")) as logger: - # Task that is queued on the worker when the worker pauses - y = c.submit(inc, 1, key="y") - while "y" not in a.tasks: - await asyncio.sleep(0.01) - - # Hog the worker with 900MB unmanaged memory - mocked_rss = 900_000_000 - while s.workers[a.address].status != Status.paused: - await asyncio.sleep(0.01) - - assert "Pausing worker" in logger.getvalue() - - # Task that is queued on the scheduler when the worker pauses. - # It is not sent to the worker. - z = c.submit(inc, 2, key="z") - while "z" not in s.tasks or s.tasks["z"].state != "no-worker": - await asyncio.sleep(0.01) - - # Test that a task that already started when the worker paused can complete - # and its output can be retrieved. Also test that the now free slot won't be - # used by other tasks. - await ev_x.set() - assert await x == 1 - await asyncio.sleep(0.05) - - assert a.executing_count == 0 - assert len(a.ready) == 1 - assert a.tasks["y"].state == "ready" - assert "z" not in a.tasks - - # Release the memory. Tasks that were queued on the worker are executed. - # Tasks that were stuck on the scheduler are sent to the worker and executed. 
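Both the removed test above and its replacement later in this diff lean on the same scheduler-side behaviour: while the only worker is paused, a newly submitted task is not dispatched at all but parks in the scheduler's unrunnable set in the no-worker state, and is released as soon as the worker resumes. A condensed sketch of just that hand-off (the function name is illustrative):

import asyncio

from distributed.core import Status
from distributed.utils_test import gen_cluster, inc


@gen_cluster(
    client=True,
    nthreads=[("", 1)],
    config={"distributed.worker.memory.pause": False},
)
async def unrunnable_while_paused_sketch(c, s, a):
    a.status = Status.paused
    while s.workers[a.address].status != Status.paused:
        await asyncio.sleep(0.01)

    # The task cannot be assigned anywhere, so it waits on the scheduler
    z = c.submit(inc, 2, key="z")
    while "z" not in s.tasks or s.tasks["z"].state != "no-worker":
        await asyncio.sleep(0.01)
    assert s.unrunnable == {s.tasks["z"]}

    # Un-pausing lets the scheduler assign and run it
    a.status = Status.running
    assert await z == 3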
- mocked_rss = 0 - assert await y == 2 - assert await z == 3 - - assert a.status == Status.running - assert "Resuming worker" in logger.getvalue() - - @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "50 ms"}) async def test_statistical_profiling_cycle(c, s, a, b): futures = c.map(slowinc, range(20), delay=0.05) @@ -1486,31 +1131,6 @@ async def test_deque_handler(s): assert any(msg.msg == "foo456" for msg in deque_handler.deque) -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs={"memory_limit": 0, "memory_monitor_interval": "10ms"}, -) -async def test_avoid_memory_monitor_if_zero_limit(c, s, a): - assert type(a.data) is dict - assert "memory" not in a.periodic_callbacks - future = c.submit(inc, 1) - assert (await future) == 2 - await asyncio.sleep(0.05) - await c.submit(inc, 2) # worker doesn't pause - - -@gen_cluster( - nthreads=[("127.0.0.1", 1)], - config={ - "distributed.worker.memory.spill": False, - "distributed.worker.memory.target": False, - }, -) -async def test_dict_data_if_no_spill_to_disk(s, w): - assert type(w.data) is dict - - def test_get_worker_name(client): def f(): get_client().submit(inc, 1).result() @@ -1526,11 +1146,6 @@ def func(dask_scheduler): assert time() < start + 10 -@gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) -async def test_parse_memory_limit(s, w): - assert w.memory_limit == 2e9 - - @gen_cluster(nthreads=[], client=True) async def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): @@ -1670,28 +1285,6 @@ async def test_register_worker_callbacks_err(c, s, a, b): await c.register_worker_callbacks(setup=lambda: 1 / 0) -@gen_cluster(nthreads=[]) -async def test_data_types(s): - w = await Worker(s.address, data=dict) - assert isinstance(w.data, dict) - await w.close() - - data = dict() - w = await Worker(s.address, data=data) - assert w.data is data - await w.close() - - class Data(dict): - def __init__(self, x, y): - self.x = x - self.y = y - - w = await Worker(s.address, data=(Data, {"x": 123, "y": 456})) - assert w.data.x == 123 - assert w.data.y == 456 - await w.close() - - @gen_cluster(nthreads=[]) async def test_local_directory(s): with tmpfile() as fn: @@ -1723,16 +1316,6 @@ async def test_host_address(c, s): await n.close() -def test_resource_limit(monkeypatch): - assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 - - new_limit = 1024 * 1024 * 200 - import distributed.worker - - monkeypatch.setattr(distributed.system, "MEMORY_LIMIT", new_limit) - assert parse_memory_limit("250MiB", 1, total_cores=1) == new_limit - - @pytest.mark.asyncio @pytest.mark.parametrize("Worker", [Worker, Nanny]) async def test_interface_async(cleanup, loop, Worker): @@ -3390,38 +2973,18 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): ) -@pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - nthreads=[("", 1)], - config={"distributed.worker.memory.pause": 0.5}, - worker_kwargs={"memory_limit": 2**29}, # 500 MiB -) -async def test_worker_status_sync(c, s, a): - (ws,) = s.workers.values() - - while ws.status != Status.running: - await asyncio.sleep(0.01) - - def leak(): - distributed._test_leak = "x" * 2**28 # 250 MiB - - def clear_leak(): - del distributed._test_leak - - await c.run(leak) - +@gen_cluster(nthreads=[("", 1)], config={"distributed.worker.memory.pause": False}) +async def test_worker_status_sync(s, a): + ws = s.workers[a.address] + a.status = Status.paused while ws.status != Status.paused: await 
asyncio.sleep(0.01) - await c.run(clear_leak) - + a.status = Status.running while ws.status != Status.running: await asyncio.sleep(0.01) await s.retire_workers() - while ws.status != Status.closed: await asyncio.sleep(0.01) @@ -3710,12 +3273,11 @@ async def test_Worker__to_dict(c, s, a): x = c.submit(inc, 1, key="x") await wait(x) d = a._to_dict() - assert d.keys() == { + assert set(d) == { "type", "id", "scheduler", "nthreads", - "memory_limit", "address", "status", "thread_id", @@ -3727,48 +3289,20 @@ async def test_Worker__to_dict(c, s, a): "in_flight_workers", "log", "tasks", - "memory_target_fraction", - "memory_spill_fraction", - "memory_pause_fraction", "logs", "config", "incoming_transfer_log", "outgoing_transfer_log", "data_needed", "pending_data_per_worker", + # attributes of WorkerMemoryManager + "data", + "max_spill", + "memory_limit", + "memory_monitor_interval", + "memory_pause_fraction", + "memory_spill_fraction", + "memory_target_fraction", } assert d["tasks"]["x"]["key"] == "x" - - -@gen_cluster(nthreads=[]) -async def test_do_not_block_event_loop_during_shutdown(s): - loop = asyncio.get_running_loop() - called_handler = threading.Event() - block_handler = threading.Event() - - w = await Worker(s.address) - executor = w.executors["default"] - - # The block wait must be smaller than the test timeout and smaller than the - # default value for timeout in `Worker.close`` - async def block(): - def fn(): - called_handler.set() - assert block_handler.wait(20) - - await loop.run_in_executor(executor, fn) - - async def set_future(): - while True: - try: - await loop.run_in_executor(executor, sleep, 0.1) - except RuntimeError: # executor has started shutting down - block_handler.set() - return - - async def close(): - called_handler.wait() - # executor_wait is True by default but we want to be explicit here - await w.close(executor_wait=True) - - await asyncio.gather(block(), close(), set_future()) + assert d["data"] == ["x"] diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py new file mode 100644 index 00000000000..030daa914d6 --- /dev/null +++ b/distributed/tests/test_worker_memory.py @@ -0,0 +1,680 @@ +from __future__ import annotations + +import asyncio +import logging +from collections import UserDict +from time import sleep + +import pytest + +import dask.config + +import distributed.system +from distributed import Client, Event, Nanny, Worker, wait +from distributed.core import Status +from distributed.spill import has_zict_210 +from distributed.utils_test import captured_logger, gen_cluster, inc +from distributed.worker_memory import parse_memory_limit + +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) + + +def memory_monitor_running(dask_worker: Worker | Nanny) -> bool: + return "memory_monitor" in dask_worker.periodic_callbacks + + +def test_parse_memory_limit_zero(): + assert parse_memory_limit(0, 1) is None + assert parse_memory_limit("0", 1) is None + assert parse_memory_limit(None, 1) is None + + +def test_resource_limit(monkeypatch): + assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 + + new_limit = 1024 * 1024 * 200 + monkeypatch.setattr(distributed.system, "MEMORY_LIMIT", new_limit) + assert parse_memory_limit("250MiB", 1, total_cores=1) == new_limit + + +@gen_cluster(nthreads=[("", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) +async def test_parse_memory_limit_worker(s, w): + assert w.memory_manager.memory_limit == 2e9 + + +@gen_cluster( 
+ client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"memory_limit": "2e3 MB"}, +) +async def test_parse_memory_limit_nanny(c, s, n): + assert n.memory_manager.memory_limit == 2e9 + out = await c.run(lambda dask_worker: dask_worker.memory_manager.memory_limit) + assert out[n.worker_address] == 2e9 + + +@gen_cluster( + nthreads=[("127.0.0.1", 1)], + config={ + "distributed.worker.memory.spill": False, + "distributed.worker.memory.target": False, + }, +) +async def test_dict_data_if_no_spill_to_disk(s, w): + assert type(w.data) is dict + + +class CustomError(Exception): + pass + + +class FailToPickle: + def __init__(self, *, reported_size=0): + self.reported_size = int(reported_size) + + def __getstate__(self): + raise CustomError() + + def __sizeof__(self): + return self.reported_size + + +async def assert_basic_futures(c: Client) -> None: + futures = c.map(inc, range(10)) + results = await c.gather(futures) + assert results == list(map(inc, range(10))) + + +@gen_cluster(client=True) +async def test_fail_to_pickle_target_1(c, s, a, b): + """Test failure to serialize triggered by key which is individually larger + than target. The data is lost and the task is marked as failed; + the worker remains in usable condition. + """ + x = c.submit(FailToPickle, reported_size=100e9, key="x") + await wait(x) + + assert x.status == "error" + + with pytest.raises(TypeError, match="Could not serialize"): + await x + + await assert_basic_futures(c) + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 kiB"}, + config={ + "distributed.worker.memory.target": 0.5, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_fail_to_pickle_target_2(c, s, a): + """Test failure to spill triggered by key which is individually smaller + than target, so it is not spilled immediately. The data is retained and + the task is NOT marked as failed; the worker remains in usable condition. 
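The split between test_fail_to_pickle_target_1 and _2 hinges on a single comparison: a key whose reported sizeof() alone exceeds target * memory_limit must be spilled (and therefore serialized) immediately, so an unpicklable value surfaces as a task error; a smaller key simply stays in fast and nothing is pickled. A back-of-the-envelope check with the numbers used in _2 (purely illustrative arithmetic):

# Numbers from test_fail_to_pickle_target_2: memory_limit="1 kiB", target=0.5
memory_limit = 1024
target = 0.5
threshold = target * memory_limit  # 512 bytes of managed memory

assert 256 < threshold    # _2: reported_size=256 stays in fast, never serialized
assert 100e9 > threshold  # a key this large (as in _1, which uses the default
                          # memory_limit) must spill at once, so the pickling
                          # error is raised while storing the task's output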
+ """ + x = c.submit(FailToPickle, reported_size=256, key="x") + await wait(x) + assert x.status == "finished" + assert set(a.data.memory) == {"x"} + + y = c.submit(lambda: "y" * 256, key="y") + await wait(y) + if has_zict_210: + assert set(a.data.memory) == {"x", "y"} + else: + assert set(a.data.memory) == {"y"} + + assert not a.data.disk + + await assert_basic_futures(c) + + +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 kB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.monitor-interval": "10ms", + }, +) +async def test_fail_to_pickle_spill(c, s, a): + """Test failure to evict a key, triggered by the spill threshold""" + a.monitor.get_process_memory = lambda: 701 if a.data.fast else 0 + + with captured_logger(logging.getLogger("distributed.spill")) as logs: + bad = c.submit(FailToPickle, key="bad") + await wait(bad) + + # Must wait for memory monitor to kick in + while True: + logs_value = logs.getvalue() + if logs_value: + break + await asyncio.sleep(0.01) + + assert "Failed to pickle" in logs_value + assert "Traceback" in logs_value + + # key is in fast + assert bad.status == "finished" + assert bad.key in a.data.fast + + await assert_basic_futures(c) + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 1200 / 0.6}, + config={ + "distributed.worker.memory.target": 0.6, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_spill_target_threshold(c, s, a): + """Test distributed.worker.memory.target threshold. Note that in this test we + disabled spill and pause thresholds, which work on the process memory, and just left + the target threshold, which works on managed memory so it is unperturbed by the + several hundreds of MB of unmanaged memory that are typical of the test suite. 
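For the test that follows, the arithmetic is worth spelling out: the target threshold acts on managed memory, i.e. the sum of sizeof() over the keys held in fast, compared against memory_limit * target = 2000 * 0.6 = 1200 bytes. A quick sanity check of those numbers (exact sizes are CPython-specific, hence a print rather than asserts):

from dask.sizeof import sizeof

memory_limit = 1200 / 0.6          # 2000 bytes, as in the worker_kwargs above
target = 0.6
threshold = memory_limit * target  # 1200 bytes of managed memory

print(sizeof("x" * 500))  # ~549 bytes on CPython
# Two such keys (~1098 bytes) stay under the 1200-byte threshold; a third
# (~1647 bytes total) crosses it and the least recently used key is spilled,
# which is exactly the x -> disk transition asserted in the test body.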
+ """ + assert not memory_monitor_running(a) + + x = c.submit(lambda: "x" * 500, key="x") + await wait(x) + y = c.submit(lambda: "y" * 500, key="y") + await wait(y) + + assert set(a.data) == {"x", "y"} + assert set(a.data.memory) == {"x", "y"} + + z = c.submit(lambda: "z" * 500, key="z") + await wait(z) + assert set(a.data) == {"x", "y", "z"} + assert set(a.data.memory) == {"y", "z"} + assert set(a.data.disk) == {"x"} + + await x + assert set(a.data.memory) == {"x", "z"} + assert set(a.data.disk) == {"y"} + + +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 1600}, + config={ + "distributed.worker.memory.target": 0.6, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.max-spill": 600, + }, +) +async def test_spill_constrained(c, s, w): + """Test distributed.worker.memory.max-spill parameter""" + # spills starts at 1600*0.6=960 bytes of managed memory + + # size in memory ~200; size on disk ~400 + x = c.submit(lambda: "x" * 200, key="x") + await wait(x) + # size in memory ~500; size on disk ~700 + y = c.submit(lambda: "y" * 500, key="y") + await wait(y) + + assert set(w.data) == {x.key, y.key} + assert set(w.data.memory) == {x.key, y.key} + + z = c.submit(lambda: "z" * 500, key="z") + await wait(z) + + assert set(w.data) == {x.key, y.key, z.key} + + # max_spill has not been reached + assert set(w.data.memory) == {y.key, z.key} + assert set(w.data.disk) == {x.key} + + # zb is individually larger than max_spill + zb = c.submit(lambda: "z" * 1700, key="zb") + await wait(zb) + + assert set(w.data.memory) == {y.key, z.key, zb.key} + assert set(w.data.disk) == {x.key} + + del zb + while "zb" in w.data: + await asyncio.sleep(0.01) + + # zc is individually smaller than max_spill, but the evicted key together with + # x it exceeds max_spill + zc = c.submit(lambda: "z" * 500, key="zc") + await wait(zc) + assert set(w.data.memory) == {y.key, z.key, zc.key} + assert set(w.data.disk) == {x.key} + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + worker_kwargs={"memory_limit": "1000 MB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval": "10ms", + }, +) +async def test_spill_spill_threshold(c, s, a): + """Test distributed.worker.memory.spill threshold. + Test that the spill threshold uses the process memory and not the managed memory + reported by sizeof(), which may be inaccurate. 
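test_spill_constrained above mixes two different measures of size: the managed weight (sizeof()) decides when to spill, while the pickled size decides how much of the 600-byte max-spill budget a key consumes. A small helper reproduces the "~200 in memory / ~400 on disk" figures quoted in the comments (exact values vary by Python version, hence prints; the helper name is made up):

from dask.sizeof import sizeof

from distributed.protocol import serialize_bytelist


def in_memory_and_on_disk(obj):
    """Managed weight vs. number of bytes the key would occupy once spilled"""
    return sizeof(obj), sum(len(frame) for frame in serialize_bytelist(obj))


print(in_memory_and_on_disk("x" * 200))  # roughly (250, 400)
print(in_memory_and_on_disk("y" * 500))  # roughly (550, 700)
# Spilling x (~400 bytes on disk) fits within max-spill=600; spilling y as well
# (~700 more) would not, so y stays in memory even though the managed-memory
# target of 1600 * 0.6 = 960 bytes has been exceeded.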
+ """ + assert memory_monitor_running(a) + a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0 + x = c.submit(inc, 0, key="x") + while not a.data.disk: + await asyncio.sleep(0.01) + assert await x == 1 + + +@pytest.mark.parametrize( + "target,managed,expect_spilled", + [ + # no target -> no hysteresis + # Over-report managed memory to test that the automated LRU eviction based on + # target is never triggered + (False, int(10e9), 1), + # Under-report managed memory, so that we reach the spill threshold for process + # memory without first reaching the target threshold for managed memory + # target == spill -> no hysteresis + (0.7, 0, 1), + # target < spill -> hysteresis from spill to target + (0.4, 0, 7), + ], +) +@gen_cluster( + nthreads=[], + client=True, + config={ + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval": "10ms", + }, +) +async def test_spill_hysteresis(c, s, target, managed, expect_spilled): + """ + 1. Test that you can enable the spill threshold while leaving the target threshold + to False + 2. Test the hysteresis system where, once you reach the spill threshold, the worker + won't stop spilling until the target threshold is reached + """ + + class C: + def __sizeof__(self): + return managed + + with dask.config.set({"distributed.worker.memory.target": target}): + async with Worker(s.address, memory_limit="1000 MB") as a: + a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast) + + # Add 500MB (reported) process memory. Spilling must not happen. + futures = [c.submit(C, pure=False) for _ in range(10)] + await wait(futures) + await asyncio.sleep(0.1) + assert not a.data.disk + + # Add another 250MB unmanaged memory. This must trigger the spilling. + futures += [c.submit(C, pure=False) for _ in range(5)] + await wait(futures) + + # Wait until spilling starts. Then, wait until it stops. + prev_n = 0 + while not a.data.disk or len(a.data.disk) > prev_n: + prev_n = len(a.data.disk) + await asyncio.sleep(0) + + assert len(a.data.disk) == expect_spilled + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_pause_executor_manual(c, s, a): + assert not memory_monitor_running(a) + + # Task that is running when the worker pauses + ev_x = Event() + + def f(ev): + ev.wait() + return 1 + + # Task that is running on the worker when the worker pauses + x = c.submit(f, ev_x, key="x") + while a.executing_count != 1: + await asyncio.sleep(0.01) + + # Task that is queued on the worker when the worker pauses + y = c.submit(inc, 1, key="y") + while "y" not in a.tasks: + await asyncio.sleep(0.01) + + a.status = Status.paused + # Wait for sync to scheduler + while s.workers[a.address].status != Status.paused: + await asyncio.sleep(0.01) + + # Task that is queued on the scheduler when the worker pauses. + # It is not sent to the worker. + z = c.submit(inc, 2, key="z") + while "z" not in s.tasks or s.tasks["z"].state != "no-worker": + await asyncio.sleep(0.01) + assert s.unrunnable == {s.tasks["z"]} + + # Test that a task that already started when the worker paused can complete + # and its output can be retrieved. Also test that the now free slot won't be + # used by other tasks. 
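Stepping back to test_spill_hysteresis above: the parametrized expectations follow mechanically from the thresholds. Spilling starts once mocked process memory (50 MB per in-memory key) exceeds spill * memory_limit and, once started, keeps evicting until it falls to the target fraction (or back to spill when target is disabled). A rough model of that loop reproduces the 1 / 1 / 7 expectations; this is an illustration, not the memory monitor's actual code:

memory_limit = 1000e6  # "1000 MB", as in the test
spill_frac = 0.7       # spilling starts above this fraction of process memory
per_key = 50e6         # the test mocks 50 MB of process memory per in-memory key
n_keys = 15            # 10 futures, then 5 more


def expect_spilled(target_frac):
    in_fast = n_keys
    if in_fast * per_key / memory_limit <= spill_frac:
        return 0  # spill threshold never tripped
    # Hysteresis: evict until we fall to the target fraction (or to the spill
    # fraction itself when the target threshold is disabled)
    stop_frac = target_frac if target_frac else spill_frac
    while in_fast * per_key / memory_limit > stop_frac:
        in_fast -= 1
    return n_keys - in_fast


assert expect_spilled(False) == 1  # no target -> no hysteresis
assert expect_spilled(0.7) == 1    # target == spill -> no hysteresis
assert expect_spilled(0.4) == 7    # target < spill -> spill all the way down to 400 MB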
+ await ev_x.set() + assert await x == 1 + await asyncio.sleep(0.05) + + assert a.executing_count == 0 + assert len(a.ready) == 1 + assert a.tasks["y"].state == "ready" + assert "z" not in a.tasks + + # Unpause. Tasks that were queued on the worker are executed. + # Tasks that were stuck on the scheduler are sent to the worker and executed. + a.status = Status.running + assert await y == 2 + assert await z == 3 + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + worker_kwargs={"memory_limit": "1000 MB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": 0.8, + "distributed.worker.memory.monitor-interval": "10ms", + }, +) +async def test_pause_executor_with_memory_monitor(c, s, a): + assert memory_monitor_running(a) + mocked_rss = 0 + a.monitor.get_process_memory = lambda: mocked_rss + + # Task that is running when the worker pauses + ev_x = Event() + + def f(ev): + ev.wait() + return 1 + + # Task that is running on the worker when the worker pauses + x = c.submit(f, ev_x, key="x") + while a.executing_count != 1: + await asyncio.sleep(0.01) + + with captured_logger(logging.getLogger("distributed.worker_memory")) as logger: + # Task that is queued on the worker when the worker pauses + y = c.submit(inc, 1, key="y") + while "y" not in a.tasks: + await asyncio.sleep(0.01) + + # Hog the worker with 900MB unmanaged memory + mocked_rss = 900_000_000 + while s.workers[a.address].status != Status.paused: + await asyncio.sleep(0.01) + + assert "Pausing worker" in logger.getvalue() + + # Task that is queued on the scheduler when the worker pauses. + # It is not sent to the worker. + z = c.submit(inc, 2, key="z") + while "z" not in s.tasks or s.tasks["z"].state != "no-worker": + await asyncio.sleep(0.01) + assert s.unrunnable == {s.tasks["z"]} + + # Test that a task that already started when the worker paused can complete + # and its output can be retrieved. Also test that the now free slot won't be + # used by other tasks. + await ev_x.set() + assert await x == 1 + await asyncio.sleep(0.05) + + assert a.executing_count == 0 + assert len(a.ready) == 1 + assert a.tasks["y"].state == "ready" + assert "z" not in a.tasks + + # Release the memory. Tasks that were queued on the worker are executed. + # Tasks that were stuck on the scheduler are sent to the worker and executed. 
+ mocked_rss = 0 + assert await y == 2 + assert await z == 3 + + assert a.status == Status.running + assert "Resuming worker" in logger.getvalue() + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.monitor-interval": "10ms"}, +) +async def test_avoid_memory_monitor_if_zero_limit_worker(c, s, a): + assert type(a.data) is dict + assert not memory_monitor_running(a) + + future = c.submit(inc, 1) + assert await future == 2 + await asyncio.sleep(0.05) + assert await c.submit(inc, 2) == 3 # worker doesn't pause + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.monitor-interval": "10ms"}, +) +async def test_avoid_memory_monitor_if_zero_limit_nanny(c, s, nanny): + typ = await c.run(lambda dask_worker: type(dask_worker.data)) + assert typ == {nanny.worker_address: dict} + assert not memory_monitor_running(nanny) + assert not (await c.run(memory_monitor_running))[nanny.worker_address] + + future = c.submit(inc, 1) + assert await future == 2 + await asyncio.sleep(0.02) + assert await c.submit(inc, 2) == 3 # worker doesn't pause + + +@gen_cluster(nthreads=[]) +async def test_override_data_worker(s): + # Use a UserDict to sidestep potential special case handling for dict + async with Worker(s.address, data=UserDict) as w: + assert type(w.data) is UserDict + + data = UserDict({"x": 1}) + async with Worker(s.address, data=data) as w: + assert w.data is data + assert w.data == {"x": 1} + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"data": UserDict}, +) +async def test_override_data_nanny(c, s, n): + r = await c.run(lambda dask_worker: type(dask_worker.data)) + assert r[n.worker_address] is UserDict + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 GB", "data": UserDict}, + config={"distributed.worker.memory.monitor-interval": "10ms"}, +) +async def test_override_data_vs_memory_monitor(c, s, a): + a.monitor.get_process_memory = lambda: 801_000_000 if a.data else 0 + assert memory_monitor_running(a) + + # Push a key that would normally trip both the target and the spill thresholds + class C: + def __sizeof__(self): + return 801_000_000 + + # Capture output of log_errors() + with captured_logger(logging.getLogger("distributed.utils")) as logger: + x = c.submit(C) + await wait(x) + + # The pause subsystem of the memory monitor has been tripped. + # The spill subsystem hasn't. 
+ while a.status != Status.paused: + await asyncio.sleep(0.01) + await asyncio.sleep(0.05) + + # This would happen if memory_monitor() tried to blindly call SpillBuffer.evict() + assert "Traceback" not in logger.getvalue() + + assert type(a.data) is UserDict + assert a.data.keys() == {x.key} + + +class ManualEvictDict(UserDict): + """A MutableMapping which implements distributed.spill.ManualEvictProto""" + + def __init__(self): + super().__init__() + self.evicted = set() + + @property + def fast(self): + # Any Sized of bool will do + return self.keys() - self.evicted + + def evict(self): + # Evict a random key + k = next(iter(self.fast)) + self.evicted.add(k) + return 1 + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 GB", "data": ManualEvictDict}, + config={ + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval": "10ms", + }, +) +async def test_manual_evict_proto(c, s, a): + """data is a third-party dict-like which respects the ManualEvictProto duck-type + API. spill threshold is respected. + """ + a.monitor.get_process_memory = lambda: 701_000_000 if a.data else 0 + assert memory_monitor_running(a) + assert isinstance(a.data, ManualEvictDict) + + futures = await c.scatter({"x": None, "y": None, "z": None}) + while a.data.evicted != {"x", "y", "z"}: + await asyncio.sleep(0.01) + + +@pytest.mark.slow +@gen_cluster( + nthreads=[("", 1)], + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "400 MiB"}, + config={"distributed.worker.memory.monitor-interval": "10ms"}, +) +async def test_nanny_terminate(c, s, a): + def leak(): + L = [] + while True: + L.append(b"0" * 5_000_000) + sleep(0.01) + + before = a.process.pid + with captured_logger(logging.getLogger("distributed.worker_memory")) as logger: + future = c.submit(leak) + while a.process.pid == before: + await asyncio.sleep(0.01) + + out = logger.getvalue() + assert "restart" in out.lower() + assert "memory" in out.lower() + + +@pytest.mark.parametrize( + "cls,name,value", + [ + (Worker, "memory_limit", 123e9), + (Worker, "memory_target_fraction", 0.789), + (Worker, "memory_spill_fraction", 0.789), + (Worker, "memory_pause_fraction", 0.789), + (Nanny, "memory_limit", 123e9), + (Nanny, "memory_terminate_fraction", 0.789), + ], +) +@gen_cluster(nthreads=[]) +async def test_deprecated_attributes(s, cls, name, value): + async with cls(s.address) as a: + with pytest.warns(FutureWarning, match=name): + setattr(a, name, value) + with pytest.warns(FutureWarning, match=name): + assert getattr(a, name) == value + assert getattr(a.memory_manager, name) == value + + +@gen_cluster(nthreads=[("", 1)]) +async def test_deprecated_memory_monitor_method_worker(s, a): + with pytest.warns(FutureWarning, match="memory_monitor"): + await a.memory_monitor() + + +@gen_cluster(nthreads=[("", 1)], Worker=Nanny) +async def test_deprecated_memory_monitor_method_nanny(s, a): + with pytest.warns(FutureWarning, match="memory_monitor"): + a.memory_monitor() + + +@pytest.mark.parametrize( + "name", + ["memory_target_fraction", "memory_spill_fraction", "memory_pause_fraction"], +) +@gen_cluster(nthreads=[]) +async def test_deprecated_params(s, name): + with pytest.warns(FutureWarning, match=name): + async with Worker(s.address, **{name: 0.789}) as a: + assert getattr(a.memory_manager, name) == 0.789 diff --git a/distributed/utils.py b/distributed/utils.py index eecb6734731..afe048c2ac8 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -74,8 +74,6 @@ def 
_initialize_mp_context(): if method == "forkserver": # Makes the test suite much faster preload = ["distributed"] - if "pkg_resources" in sys.modules: - preload.append("pkg_resources") from distributed.versions import optional_packages, required_packages diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 3a4a31e7263..dc3010cab90 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1,6 +1,8 @@ from __future__ import annotations import asyncio +import concurrent.futures +import contextlib import copy import functools import gc @@ -11,19 +13,16 @@ import multiprocessing import os import re -import shutil import signal import socket import subprocess import sys import tempfile import threading -import uuid import weakref from collections import defaultdict from collections.abc import Callable from contextlib import contextmanager, nullcontext, suppress -from glob import glob from itertools import count from time import sleep from typing import Any, Literal @@ -490,12 +489,13 @@ def run_worker(q, scheduler_q, config, **kwargs): scheduler_addr = scheduler_q.get() async def _(): + pid = os.getpid() try: worker = await Worker(scheduler_addr, validate=True, **kwargs) except Exception as exc: - q.put(exc) + q.put((pid, exc)) else: - q.put(worker.address) + q.put((pid, worker.address)) await worker.finished() # Scheduler might've failed @@ -513,12 +513,13 @@ def run_nanny(q, scheduler_q, config, **kwargs): scheduler_addr = scheduler_q.get() async def _(): + pid = os.getpid() try: worker = await Nanny(scheduler_addr, validate=True, **kwargs) except Exception as exc: - q.put(exc) + q.put((pid, exc)) else: - q.put(worker.address) + q.put((pid, worker.address)) await worker.finished() # Scheduler might've failed @@ -629,6 +630,31 @@ def security(): return tls_only_security() +def _terminate_join(proc): + proc.terminate() + proc.join() + proc.close() + + +def _close_queue(q): + q.close() + q.join_thread() + q._writer.close() # https://bugs.python.org/issue42752 + + +class _SafeTemporaryDirectory(tempfile.TemporaryDirectory): + def __exit__(self, exc_type, exc_val, exc_tb): + try: + return super().__exit__(exc_type, exc_val, exc_tb) + except PermissionError: + # It appears that we either have a process still interacting with + # the tmpdirs of the workers or that win process are not releasing + # their lock in time. 
We are receiving PermissionErrors during + # teardown + # See also https://github.com/dask/distributed/pull/5825 + pass + + @contextmanager def cluster( nworkers=2, @@ -648,115 +674,104 @@ def cluster( else: _run_worker = run_worker - # The scheduler queue will receive the scheduler's address - scheduler_q = mp_context.Queue() - - # Launch scheduler - scheduler = mp_context.Process( - name="Dask cluster test: Scheduler", - target=run_scheduler, - args=(scheduler_q, nworkers + 1, config), - kwargs=scheduler_kwargs, - ) - ws.add(scheduler) - scheduler.daemon = True - scheduler.start() - - # Launch workers - workers = [] - for i in range(nworkers): - q = mp_context.Queue() - fn = "_test_worker-%s" % uuid.uuid4() - kwargs = merge( - { - "nthreads": 1, - "local_directory": fn, - "memory_limit": system.MEMORY_LIMIT, - }, - worker_kwargs, - ) - proc = mp_context.Process( - name="Dask cluster test: Worker", - target=_run_worker, - args=(q, scheduler_q, config), - kwargs=kwargs, + with contextlib.ExitStack() as stack: + # The scheduler queue will receive the scheduler's address + scheduler_q = mp_context.Queue() + stack.callback(_close_queue, scheduler_q) + + # Launch scheduler + scheduler = mp_context.Process( + name="Dask cluster test: Scheduler", + target=run_scheduler, + args=(scheduler_q, nworkers + 1, config), + kwargs=scheduler_kwargs, + daemon=True, ) - ws.add(proc) - workers.append({"proc": proc, "queue": q, "dir": fn}) - - for worker in workers: - worker["proc"].start() - saddr_or_exception = scheduler_q.get() - if isinstance(saddr_or_exception, Exception): - raise saddr_or_exception - saddr = saddr_or_exception - - for worker in workers: - addr_or_exception = worker["queue"].get() - if isinstance(addr_or_exception, Exception): - raise addr_or_exception - worker["address"] = addr_or_exception - - start = time() - try: - try: - security = scheduler_kwargs["security"] - rpc_kwargs = {"connection_args": security.get_connection_args("client")} - except KeyError: - rpc_kwargs = {} - - with rpc(saddr, **rpc_kwargs) as s: - while True: - nthreads = loop.run_sync(s.ncores) - if len(nthreads) == nworkers: - break - if time() - start > 5: - raise Exception("Timeout on cluster creation") - - # avoid sending processes down to function - yield {"address": saddr}, [ - {"address": w["address"], "proc": weakref.ref(w["proc"])} - for w in workers - ] - finally: - logger.debug("Closing out test cluster") + ws.add(scheduler) + scheduler.start() + stack.callback(_terminate_join, scheduler) - loop.run_sync( - lambda: disconnect_all( - [w["address"] for w in workers], - timeout=disconnect_timeout, - rpc_kwargs=rpc_kwargs, + # Launch workers + workers_by_pid = {} + q = mp_context.Queue() + stack.callback(_close_queue, q) + for _ in range(nworkers): + tmpdirname = stack.enter_context( + _SafeTemporaryDirectory(prefix="_dask_test_worker") ) - ) - loop.run_sync( - lambda: disconnect( - saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs + kwargs = merge( + { + "nthreads": 1, + "local_directory": tmpdirname, + "memory_limit": system.MEMORY_LIMIT, + }, + worker_kwargs, ) - ) - - scheduler.terminate() - scheduler_q.close() - scheduler_q._reader.close() - scheduler_q._writer.close() - - for w in workers: - w["proc"].terminate() - w["queue"].close() - w["queue"]._reader.close() - w["queue"]._writer.close() - - scheduler.join(2) - del scheduler - for proc in [w["proc"] for w in workers]: - proc.join(timeout=30) - - with suppress(UnboundLocalError): - del worker, w, proc - del workers[:] - - for fn in 
glob("_test_worker-*"): - with suppress(OSError): - shutil.rmtree(fn) + proc = mp_context.Process( + name="Dask cluster test: Worker", + target=_run_worker, + args=(q, scheduler_q, config), + kwargs=kwargs, + ) + ws.add(proc) + proc.start() + stack.callback(_terminate_join, proc) + workers_by_pid[proc.pid] = {"proc": proc} + + saddr_or_exception = scheduler_q.get() + if isinstance(saddr_or_exception, Exception): + raise saddr_or_exception + saddr = saddr_or_exception + + for _ in range(nworkers): + pid, addr_or_exception = q.get() + if isinstance(addr_or_exception, Exception): + raise addr_or_exception + workers_by_pid[pid]["address"] = addr_or_exception + + start = time() + try: + try: + security = scheduler_kwargs["security"] + rpc_kwargs = { + "connection_args": security.get_connection_args("client") + } + except KeyError: + rpc_kwargs = {} + + with rpc(saddr, **rpc_kwargs) as s: + while True: + nthreads = loop.run_sync(s.ncores) + if len(nthreads) == nworkers: + break + if time() - start > 5: + raise Exception("Timeout on cluster creation") + + # avoid sending processes down to function + yield {"address": saddr}, [ + {"address": w["address"], "proc": weakref.ref(w["proc"])} + for w in workers_by_pid.values() + ] + finally: + logger.debug("Closing out test cluster") + alive_workers = [ + w["address"] + for w in workers_by_pid.values() + if w["proc"].is_alive() + ] + loop.run_sync( + lambda: disconnect_all( + alive_workers, + timeout=disconnect_timeout, + rpc_kwargs=rpc_kwargs, + ) + ) + if scheduler.is_alive(): + loop.run_sync( + lambda: disconnect( + saddr, timeout=disconnect_timeout, rpc_kwargs=rpc_kwargs + ) + ) try: client = default_client() @@ -765,12 +780,6 @@ def cluster( else: client.close() - start = time() - while any(proc.is_alive() for proc in ws): - text = str(list(ws)) - sleep(0.2) - assert time() < start + 5, ("Workers still around after five seconds", text) - async def disconnect(addr, timeout=3, rpc_kwargs=None): rpc_kwargs = rpc_kwargs or {} @@ -1091,7 +1100,6 @@ def get_unclosed(): # zict backends can fail if their storage directory # was already removed pass - del w.data return result @@ -1173,9 +1181,27 @@ def _terminate_process(proc): @contextmanager -def popen(args, **kwargs): +def popen(args: list[str], flush_output: bool = True, **kwargs): + """Start a shell command in a subprocess. + Yields a subprocess.Popen object. + + stderr is redirected to stdout. + stdout is redirected to a pipe. + + Parameters + ---------- + args: list[str] + Command line arguments + flush_output: bool, optional + If True (the default), the stdout/stderr pipe is emptied while it is being + filled. Set to False if you wish to read the output yourself. Note that setting + this to False and then failing to periodically read from the pipe may result in + a deadlock due to the pipe getting full. 
+ kwargs: optional + optional arguments to subprocess.Popen + """ kwargs["stdout"] = subprocess.PIPE - kwargs["stderr"] = subprocess.PIPE + kwargs["stderr"] = subprocess.STDOUT if sys.platform.startswith("win"): # Allow using CTRL_C_EVENT / CTRL_BREAK_EVENT kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP @@ -1189,9 +1215,16 @@ def popen(args, **kwargs): os.environ.get("DESTDIR", "") + sys.prefix, "bin", args[0] ) proc = subprocess.Popen(args, **kwargs) + + if flush_output: + ex = concurrent.futures.ThreadPoolExecutor(1) + flush_future = ex.submit(proc.communicate) + try: yield proc - except Exception: + + # asyncio.CancelledError is raised by @gen_test/@gen_cluster timeout + except (Exception, asyncio.CancelledError): dump_stdout = True raise @@ -1200,13 +1233,17 @@ def popen(args, **kwargs): _terminate_process(proc) finally: # XXX Also dump stdout if return code != 0 ? - out, err = proc.communicate() - if dump_stdout: - print("\n\nPrint from stderr\n %s\n=================\n" % args[0][0]) - print(err.decode()) + if flush_output: + out, err = flush_future.result() + ex.shutdown() + else: + out, err = proc.communicate() + assert not err - print("\n\nPrint from stdout\n=================\n") - print(out.decode()) + if dump_stdout: + print("\n" + "-" * 27 + " Subprocess stdout/stderr" + "-" * 27) + print(out.decode().rstrip()) + print("-" * 80) def wait_for(predicate, timeout, fail_func=None, period=0.001): diff --git a/distributed/worker.py b/distributed/worker.py index a17d6b3222d..13e5adeff00 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -44,12 +44,11 @@ typename, ) -from distributed import comm, preloading, profile, system, utils +from distributed import comm, preloading, profile, utils from distributed.batched import BatchedSend -from distributed.comm import Comm, connect, get_address_host +from distributed.comm import connect, get_address_host from distributed.comm.addressing import address_from_user_args, parse_address from distributed.comm.utils import OFFLOAD_THRESHOLD -from distributed.compatibility import to_thread from distributed.core import ( CommClosedError, Status, @@ -93,12 +92,13 @@ warn_on_duration, ) from distributed.utils_comm import gather_from_workers, pack_data, retry_operation -from distributed.utils_perf import ( - ThrottledGC, - disable_gc_diagnosis, - enable_gc_diagnosis, -) +from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from distributed.versions import get_versions +from distributed.worker_memory import ( + DeprecatedMemoryManagerAttribute, + DeprecatedMemoryMonitor, + WorkerMemoryManager, +) from distributed.worker_state_machine import Instruction # noqa: F401 from distributed.worker_state_machine import ( PROCESSING, @@ -208,17 +208,6 @@ class Worker(ServerNode): * **tasks**: ``{key: TaskState}`` The tasks currently executing on this worker (and any dependencies of those tasks) - * **data:** ``{key: object}``: - Prefer using the **host** attribute instead of this, unless - memory_limit and at least one of memory_target_fraction or - memory_spill_fraction values are defined, in that case, this attribute - is a zict.Buffer, from which information on LRU cache can be queried. - * **data.memory:** ``{key: object}``: - Dictionary mapping keys to actual values stored in memory. Only - available if condition for **data** being a zict.Buffer is met. - * **data.disk:** ``{key: object}``: - Dictionary mapping keys to actual values stored on disk. 
Only - available if condition for **data** being a zict.Buffer is met. * **data_needed**: UniqueTaskHeap The tasks which still require data in order to execute, prioritized as a heap * **ready**: [keys] @@ -401,12 +390,6 @@ class Worker(ServerNode): extensions: dict security: Security connection_args: dict[str, Any] - memory_limit: int | None - memory_target_fraction: float | Literal[False] - memory_spill_fraction: float | Literal[False] - memory_pause_fraction: float | Literal[False] - max_spill: int | Literal[False] - data: MutableMapping[str, Any] # {task key: task payload} actors: dict[str, Actor | None] loop: IOLoop reconnect: bool @@ -425,9 +408,6 @@ class Worker(ServerNode): low_level_profiler: bool scheduler: Any execution_state: dict[str, Any] - memory_monitor_interval: float | None - _memory_monitoring: bool - _throttled_gc: ThrottledGC plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] @@ -444,7 +424,6 @@ def __init__( services: dict | None = None, name: Any | None = None, reconnect: bool = True, - memory_limit: str | float = "auto", executor: Executor | dict[str, Executor] | Literal["offload"] | None = None, resources: dict[str, float] | None = None, silence_logs: int | None = None, @@ -454,24 +433,11 @@ def __init__( security: Security | dict[str, Any] | None = None, contact_address: str | None = None, heartbeat_interval: Any = "1s", - memory_monitor_interval: Any = "200ms", - memory_target_fraction: float | Literal[False] | None = None, - memory_spill_fraction: float | Literal[False] | None = None, - memory_pause_fraction: float | Literal[False] | None = None, - max_spill: float | str | Literal[False] | None = None, extensions: list[type] | None = None, metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS, startup_information: Mapping[ str, Callable[[Worker], Any] ] = DEFAULT_STARTUP_INFORMATION, - data: ( - MutableMapping[str, Any] # pre-initialised - | Callable[[], MutableMapping[str, Any]] # constructor - | tuple[ - Callable[..., MutableMapping[str, Any]], dict[str, Any] - ] # (constructor, kwargs to constructor) - | None # create internatlly - ) = None, interface: str | None = None, host: str | None = None, port: int | None = None, @@ -487,6 +453,18 @@ def __init__( lifetime: Any | None = None, lifetime_stagger: Any | None = None, lifetime_restart: bool | None = None, + ################################### + # Parameters to WorkerMemoryManager + memory_limit: str | float = "auto", + # Allow overriding the dict-like that stores the task outputs. + # This is meant for power users only. See WorkerMemoryManager for details. + data=None, + # Deprecated parameters; please use dask config instead. 
+ memory_target_fraction: float | Literal[False] | None = None, + memory_spill_fraction: float | Literal[False] | None = None, + memory_pause_fraction: float | Literal[False] | None = None, + ################################### + # Parameters to Server **kwargs, ): self.tasks = {} @@ -686,54 +664,6 @@ def __init__( assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") - self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) - - self.memory_target_fraction = ( - memory_target_fraction - if memory_target_fraction is not None - else dask.config.get("distributed.worker.memory.target") - ) - self.memory_spill_fraction = ( - memory_spill_fraction - if memory_spill_fraction is not None - else dask.config.get("distributed.worker.memory.spill") - ) - self.memory_pause_fraction = ( - memory_pause_fraction - if memory_pause_fraction is not None - else dask.config.get("distributed.worker.memory.pause") - ) - - if max_spill is None: - max_spill = dask.config.get("distributed.worker.memory.max-spill") - self.max_spill = False if max_spill is False else parse_bytes(max_spill) - - if isinstance(data, MutableMapping): - self.data = data - elif callable(data): - self.data = data() - elif isinstance(data, tuple): - self.data = data[0](**data[1]) - elif self.memory_limit and ( - self.memory_target_fraction or self.memory_spill_fraction - ): - from distributed.spill import SpillBuffer - - if self.memory_target_fraction: - target = int( - self.memory_limit - * (self.memory_target_fraction or self.memory_spill_fraction) - ) - else: - target = sys.maxsize - self.data = SpillBuffer( - os.path.join(self.local_directory, "storage"), - target=target, - max_spill=self.max_spill, - ) - else: - self.data = {} - self.actors = {} self.loop = loop or IOLoop.current() self.reconnect = reconnect @@ -851,24 +781,19 @@ def __init__( self._address = contact_address - self.memory_monitor_interval = parse_timedelta( - memory_monitor_interval, default="ms" - ) - self._memory_monitoring = False - if self.memory_limit: - assert self.memory_monitor_interval is not None - pc = PeriodicCallback( - self.memory_monitor, # type: ignore - self.memory_monitor_interval * 1000, - ) - self.periodic_callbacks["memory"] = pc - if extensions is None: extensions = DEFAULT_EXTENSIONS for ext in extensions: ext(self) - self._throttled_gc = ThrottledGC(logger=logger) + self.memory_manager = WorkerMemoryManager( + self, + data=data, + memory_limit=memory_limit, + memory_target_fraction=memory_target_fraction, + memory_spill_fraction=memory_spill_fraction, + memory_pause_fraction=memory_pause_fraction, + ) setproctitle("dask-worker [not started]") @@ -902,6 +827,33 @@ def __init__( Worker._instances.add(self) + ################ + # Memory manager + ################ + memory_manager: WorkerMemoryManager + + @property + def data(self) -> MutableMapping[str, Any]: + """{task key: task payload} of all completed tasks, whether they were computed on + this Worker or computed somewhere else and then transferred here over the + network. + + When using the default configuration, this is a zict buffer that automatically + spills to disk whenever the target threshold is exceeded. + If spilling is disabled, it is a plain dict instead. + It could also be a user-defined arbitrary dict-like passed when initialising + the Worker or the Nanny. + Worker logic should treat this opaquely and stick to the MutableMapping API. 
+ """ + return self.memory_manager.data + + # Deprecated attributes moved to self.memory_manager.<name> + memory_limit = DeprecatedMemoryManagerAttribute() + memory_target_fraction = DeprecatedMemoryManagerAttribute() + memory_spill_fraction = DeprecatedMemoryManagerAttribute() + memory_pause_fraction = DeprecatedMemoryManagerAttribute() + memory_monitor = DeprecatedMemoryMonitor() + ################## # Administrative # ################## @@ -923,12 +875,13 @@ def logs(self): return self._deque_handler.deque def log_event(self, topic, msg): - self.batched_stream.send( + self.loop.add_callback( + self.batched_stream.send, { "op": "log-event", "topic": topic, "msg": msg, - } + }, ) @property @@ -944,23 +897,21 @@ def worker_address(self): """For API compatibility with Nanny""" return self.address - @property - def local_dir(self): - """For API compatibility with Nanny""" - warnings.warn( - "The local_dir attribute has moved to local_directory", stacklevel=2 - ) - return self.local_directory - @property def executor(self): return self.executors["default"] @ServerNode.status.setter # type: ignore def status(self, value): - """Override Server.status to notify the Scheduler of status changes""" + """Override Server.status to notify the Scheduler of status changes. + Also handles unpausing. + """ + prev_status = self.status ServerNode.status.__set__(self, value) self._send_worker_status_change() + if prev_status == Status.paused and value == Status.running: + self.ensure_computing() + self.ensure_communicating() def _send_worker_status_change(self) -> None: if ( @@ -1029,12 +980,10 @@ def identity(self): "id": self.id, "scheduler": self.scheduler.address, "nthreads": self.nthreads, - "memory_limit": self.memory_limit, + "memory_limit": self.memory_manager.memory_limit, } - def _to_dict( - self, comm: Comm | None = None, *, exclude: Container[str] = () - ) -> dict: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. 
@@ -1059,16 +1008,13 @@ def _to_dict( "in_flight_workers": self.in_flight_workers, "log": self.log, "tasks": self.tasks, - "memory_limit": self.memory_limit, - "memory_target_fraction": self.memory_target_fraction, - "memory_spill_fraction": self.memory_spill_fraction, - "memory_pause_fraction": self.memory_pause_fraction, "logs": self.get_logs(), "config": dask.config.config, "incoming_transfer_log": self.incoming_transfer_log, "outgoing_transfer_log": self.outgoing_transfer_log, } info.update(extra) + info.update(self.memory_manager._to_dict(exclude=exclude)) info = {k: v for k, v in info.items() if k not in exclude} return recursive_to_dict(info, exclude=exclude) @@ -1109,7 +1055,7 @@ async def _register_with_scheduler(self): types={k: typename(v) for k, v in self.data.items()}, now=time(), resources=self.total_resources, - memory_limit=self.memory_limit, + memory_limit=self.memory_manager.memory_limit, local_directory=self.local_directory, services=self.service_ports, nanny=self.nanny, @@ -1400,8 +1346,11 @@ async def start(self): logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) logger.info(" Threads: %26d", self.nthreads) - if self.memory_limit: - logger.info(" Memory: %26s", format_bytes(self.memory_limit)) + if self.memory_manager.memory_limit: + logger.info( + " Memory: %26s", + format_bytes(self.memory_manager.memory_limit), + ) logger.info(" Local Directory: %26s", self.local_directory) setproctitle("dask-worker [%s]" % self.address) @@ -1529,19 +1478,11 @@ async def close( for executor in self.executors.values(): if executor is utils._offload_executor: continue # Never shutdown the offload executor - - def _close(): - if isinstance(executor, ThreadPoolExecutor): - executor._work_queue.queue.clear() - executor.shutdown(wait=executor_wait, timeout=timeout) - else: - executor.shutdown(wait=executor_wait) - - # Waiting for the shutdown can block the event loop causing - # weird deadlocks particularly if the task that is executing in - # the thread is waiting for a server reply, e.g. when using - # worker clients, semaphores, etc. - await to_thread(_close) + if isinstance(executor, ThreadPoolExecutor): + executor._work_queue.queue.clear() + executor.shutdown(wait=executor_wait, timeout=timeout) + else: + executor.shutdown(wait=executor_wait) self.stop() await self.rpc.close() @@ -3608,115 +3549,6 @@ def _prepare_args_for_execution( ################## # Administrative # ################## - - async def memory_monitor(self) -> None: - """Track this process's memory usage and act accordingly - - If we rise above 70% memory use, start dumping data to disk. - - If we rise above 80% memory use, stop execution of new tasks - """ - if self._memory_monitoring: - return - self._memory_monitoring = True - assert self.memory_limit - total = 0 - - memory = self.monitor.get_process_memory() - frac = memory / self.memory_limit - - def check_pause(memory): - frac = memory / self.memory_limit - # Pause worker threads if above 80% memory use - if self.memory_pause_fraction and frac > self.memory_pause_fraction: - # Try to free some memory while in paused state - self._throttled_gc.collect() - if self.status == Status.running: - logger.warning( - "Worker is at %d%% memory usage. Pausing worker. 
" - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", - ) - self.status = Status.paused - elif self.status == Status.paused: - logger.warning( - "Worker is at %d%% memory usage. Resuming worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", - ) - self.status = Status.running - self.ensure_computing() - self.ensure_communicating() - - check_pause(memory) - # Dump data to disk if above 70% - if self.memory_spill_fraction and frac > self.memory_spill_fraction: - from distributed.spill import SpillBuffer - - assert isinstance(self.data, SpillBuffer) - - logger.debug( - "Worker is at %.0f%% memory usage. Start spilling data to disk.", - frac * 100, - ) - # Implement hysteresis cycle where spilling starts at the spill threshold - # and stops at the target threshold. Normally that here the target threshold - # defines process memory, whereas normally it defines reported managed - # memory (e.g. output of sizeof() ). - # If target=False, disable hysteresis. - target = self.memory_limit * ( - self.memory_target_fraction or self.memory_spill_fraction - ) - count = 0 - need = memory - target - while memory > target: - if not self.data.fast: - logger.warning( - "Unmanaged memory use is high. This may indicate a memory leak " - "or the memory may not be released to the OS; see " - "https://distributed.dask.org/en/latest/worker.html#memtrim " - "for more information. " - "-- Unmanaged memory: %s -- Worker memory limit: %s", - format_bytes(memory), - format_bytes(self.memory_limit), - ) - break - weight = self.data.evict() - if weight == -1: - # Failed to evict: - # disk full, spill size limit exceeded, or pickle error - break - - total += weight - count += 1 - await asyncio.sleep(0) - - memory = self.monitor.get_process_memory() - if total > need and memory > target: - # Issue a GC to ensure that the evicted data is actually - # freed from memory and taken into account by the monitor - # before trying to evict even more data. - self._throttled_gc.collect() - memory = self.monitor.get_process_memory() - - check_pause(memory) - if count: - logger.debug( - "Moved %d tasks worth %s to disk", - count, - format_bytes(total), - ) - - self._memory_monitoring = False - def cycle_profile(self) -> None: now = time() + self.scheduler_delay prof, self.profile_recent = self.profile_recent, profile.create() @@ -4267,25 +4099,6 @@ class Reschedule(Exception): """ -def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT) -> int | None: - if memory_limit is None: - return None - - if memory_limit == "auto": - memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores)) - with suppress(ValueError, TypeError): - memory_limit = float(memory_limit) - if isinstance(memory_limit, float) and memory_limit <= 1: - memory_limit = int(memory_limit * system.MEMORY_LIMIT) - - if isinstance(memory_limit, str): - memory_limit = parse_bytes(memory_limit) - else: - memory_limit = int(memory_limit) - - return min(memory_limit, system.MEMORY_LIMIT) - - async def get_data_from_worker( rpc, keys, diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py new file mode 100644 index 00000000000..f08dacdb2c2 --- /dev/null +++ b/distributed/worker_memory.py @@ -0,0 +1,406 @@ +"""Encapsulated manager for in-memory tasks on a worker. 
+ +This module covers: +- spill/unspill data depending on the 'distributed.worker.memory.target' threshold +- spill/unspill data depending on the 'distributed.worker.memory.spill' threshold +- pause/unpause the worker depending on the 'distributed.worker.memory.pause' threshold +- kill the worker depending on the 'distributed.worker.memory.terminate' threshold + +This module does *not* cover: +- Changes in behaviour in Worker, Scheduler, task stealing, Active Memory Manager, etc. + caused by the Worker being in paused status +- Worker restart after it's been killed +- Scheduler-side heuristics regarding memory usage, e.g. the Active Memory Manager + +See also: +- :mod:`distributed.spill`, which implements the spill-to-disk mechanism and is wrapped + by this module. Unlike this module, :mod:`distributed.spill` is agnostic to the + Worker. +- :mod:`distributed.active_memory_manager`, which runs on the scheduler side +""" +from __future__ import annotations + +import asyncio +import logging +import os +import sys +import warnings +from collections.abc import Callable, MutableMapping +from contextlib import suppress +from functools import partial +from typing import TYPE_CHECKING, Any, Container, Literal, cast + +import psutil +from tornado.ioloop import PeriodicCallback + +import dask.config +from dask.system import CPU_COUNT +from dask.utils import format_bytes, parse_bytes, parse_timedelta + +from distributed import system +from distributed.core import Status +from distributed.spill import ManualEvictProto, SpillBuffer +from distributed.utils import log_errors +from distributed.utils_perf import ThrottledGC + +if TYPE_CHECKING: + # Circular imports + from distributed.nanny import Nanny + from distributed.worker import Worker + +logger = logging.getLogger(__name__) + + +class WorkerMemoryManager: + data: MutableMapping[str, Any] # {task key: task payload} + memory_limit: int | None + memory_target_fraction: float | Literal[False] + memory_spill_fraction: float | Literal[False] + memory_pause_fraction: float | Literal[False] + max_spill: int | Literal[False] + memory_monitor_interval: float + _memory_monitoring: bool + _throttled_gc: ThrottledGC + + def __init__( + self, + worker: Worker, + *, + memory_limit: str | float = "auto", + # This should be None most of the times, short of a power user replacing the + # SpillBuffer with their own custom dict-like + data: ( + MutableMapping[str, Any] # pre-initialised + | Callable[[], MutableMapping[str, Any]] # constructor + | tuple[ + Callable[..., MutableMapping[str, Any]], dict[str, Any] + ] # (constructor, kwargs to constructor) + | None # create internally + ) = None, + # Deprecated parameters; use dask.config instead + memory_target_fraction: float | Literal[False] | None = None, + memory_spill_fraction: float | Literal[False] | None = None, + memory_pause_fraction: float | Literal[False] | None = None, + ): + self.memory_limit = parse_memory_limit(memory_limit, worker.nthreads) + + self.memory_target_fraction = _parse_threshold( + "distributed.worker.memory.target", + "memory_target_fraction", + memory_target_fraction, + ) + self.memory_spill_fraction = _parse_threshold( + "distributed.worker.memory.spill", + "memory_spill_fraction", + memory_spill_fraction, + ) + self.memory_pause_fraction = _parse_threshold( + "distributed.worker.memory.pause", + "memory_pause_fraction", + memory_pause_fraction, + ) + + max_spill = dask.config.get("distributed.worker.memory.max-spill") + self.max_spill = False if max_spill is False else 
parse_bytes(max_spill) + + if isinstance(data, MutableMapping): + self.data = data + elif callable(data): + self.data = data() + elif isinstance(data, tuple): + self.data = data[0](**data[1]) + elif self.memory_limit and ( + self.memory_target_fraction or self.memory_spill_fraction + ): + if self.memory_target_fraction: + target = int( + self.memory_limit + * (self.memory_target_fraction or self.memory_spill_fraction) + ) + else: + target = sys.maxsize + self.data = SpillBuffer( + os.path.join(worker.local_directory, "storage"), + target=target, + max_spill=self.max_spill, + ) + else: + self.data = {} + + self._memory_monitoring = False + + self.memory_monitor_interval = parse_timedelta( + dask.config.get("distributed.worker.memory.monitor-interval"), + default=None, + ) + assert isinstance(self.memory_monitor_interval, (int, float)) + + if self.memory_limit and ( + self.memory_spill_fraction is not False + or self.memory_pause_fraction is not False + ): + assert self.memory_monitor_interval is not None + pc = PeriodicCallback( + # Don't store worker as self.worker to avoid creating a circular + # dependency. We could have alternatively used a weakref. + # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117 + partial(self.memory_monitor, worker), # type: ignore + self.memory_monitor_interval * 1000, + ) + worker.periodic_callbacks["memory_monitor"] = pc + + self._throttled_gc = ThrottledGC(logger=logger) + + async def memory_monitor(self, worker: Worker) -> None: + """Track this process's memory usage and act accordingly. + If process memory rises above the spill threshold (70%), start dumping data to + disk until it goes below the target threshold (60%). + If process memory rises above the pause threshold (80%), stop execution of new + tasks. + """ + with log_errors(): + if self._memory_monitoring: + return + self._memory_monitoring = True + try: + # Don't use psutil directly; instead read from the same API that is used + # to send info to the Scheduler (e.g. for the benefit of Active Memory + # Manager) and which can be easily mocked in unit tests. + memory = worker.monitor.get_process_memory() + self._maybe_pause_or_unpause(worker, memory) + await self._maybe_spill(worker, memory) + finally: + self._memory_monitoring = False + + def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None: + if self.memory_pause_fraction is False: + return + + assert self.memory_limit + frac = memory / self.memory_limit + # Pause worker threads if above 80% memory use + if frac > self.memory_pause_fraction: + # Try to free some memory while in paused state + self._throttled_gc.collect() + if worker.status == Status.running: + logger.warning( + "Worker is at %d%% memory usage. Pausing worker. " + "Process memory: %s -- Worker memory limit: %s", + int(frac * 100), + format_bytes(memory), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", + ) + worker.status = Status.paused + elif worker.status == Status.paused: + logger.warning( + "Worker is at %d%% memory usage. Resuming worker. " + "Process memory: %s -- Worker memory limit: %s", + int(frac * 100), + format_bytes(memory), + format_bytes(self.memory_limit) + if self.memory_limit is not None + else "None", + ) + worker.status = Status.running + + async def _maybe_spill(self, worker: Worker, memory: int) -> None: + if self.memory_spill_fraction is False: + return + + # SpillBuffer or a duct-type compatible MutableMapping which offers the + # fast property and evict() methods. Dask-CUDA uses this. 
+ if not hasattr(self.data, "fast") or not hasattr(self.data, "evict"): + return + data = cast(ManualEvictProto, self.data) + + assert self.memory_limit + frac = memory / self.memory_limit + if frac <= self.memory_spill_fraction: + return + + total_spilled = 0 + logger.debug( + "Worker is at %.0f%% memory usage. Start spilling data to disk.", + frac * 100, + ) + # Implement hysteresis cycle where spilling starts at the spill threshold and + # stops at the target threshold. Normally that here the target threshold defines + # process memory, whereas normally it defines reported managed memory (e.g. + # output of sizeof() ). If target=False, disable hysteresis. + target = self.memory_limit * ( + self.memory_target_fraction or self.memory_spill_fraction + ) + count = 0 + need = memory - target + while memory > target: + if not data.fast: + logger.warning( + "Unmanaged memory use is high. This may indicate a memory leak " + "or the memory may not be released to the OS; see " + "https://distributed.dask.org/en/latest/worker.html#memtrim " + "for more information. " + "-- Unmanaged memory: %s -- Worker memory limit: %s", + format_bytes(memory), + format_bytes(self.memory_limit), + ) + break + + weight = data.evict() + if weight == -1: + # Failed to evict: + # disk full, spill size limit exceeded, or pickle error + break + + total_spilled += weight + count += 1 + await asyncio.sleep(0) + + memory = worker.monitor.get_process_memory() + if total_spilled > need and memory > target: + # Issue a GC to ensure that the evicted data is actually + # freed from memory and taken into account by the monitor + # before trying to evict even more data. + self._throttled_gc.collect() + memory = worker.monitor.get_process_memory() + + self._maybe_pause_or_unpause(worker, memory) + if count: + logger.debug( + "Moved %d tasks worth %s to disk", + count, + format_bytes(total_spilled), + ) + + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: + info = { + k: v + for k, v in self.__dict__.items() + if not k.startswith("_") and k != "data" and k not in exclude + } + info["data"] = list(self.data) + return info + + +class NannyMemoryManager: + memory_limit: int | None + memory_terminate_fraction: float | Literal[False] + memory_monitor_interval: float | None + + def __init__( + self, + nanny: Nanny, + *, + memory_limit: str | float = "auto", + ): + self.memory_limit = parse_memory_limit(memory_limit, nanny.nthreads) + self.memory_terminate_fraction = dask.config.get( + "distributed.worker.memory.terminate" + ) + self.memory_monitor_interval = parse_timedelta( + dask.config.get("distributed.worker.memory.monitor-interval"), + default=None, + ) + assert isinstance(self.memory_monitor_interval, (int, float)) + if self.memory_limit and self.memory_terminate_fraction is not False: + pc = PeriodicCallback( + partial(self.memory_monitor, nanny), + self.memory_monitor_interval * 1000, + ) + nanny.periodic_callbacks["memory_monitor"] = pc + + def memory_monitor(self, nanny: Nanny) -> None: + """Track worker's memory. 
Restart if it goes above terminate fraction.""" + if nanny.status != Status.running: + return # pragma: nocover + if nanny.process is None or nanny.process.process is None: + return # pragma: nocover + process = nanny.process.process + try: + proc = nanny._psutil_process + memory = proc.memory_info().rss + except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): + return # pragma: nocover + + if memory / self.memory_limit > self.memory_terminate_fraction: + logger.warning( + "Worker exceeded %d%% memory budget. Restarting", + 100 * self.memory_terminate_fraction, + ) + process.terminate() + + +def parse_memory_limit( + memory_limit: str | float, nthreads: int, total_cores: int = CPU_COUNT +) -> int | None: + if memory_limit is None: + return None + + if memory_limit == "auto": + memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores)) + with suppress(ValueError, TypeError): + memory_limit = float(memory_limit) + if isinstance(memory_limit, float) and memory_limit <= 1: + memory_limit = int(memory_limit * system.MEMORY_LIMIT) + + if isinstance(memory_limit, str): + memory_limit = parse_bytes(memory_limit) + else: + memory_limit = int(memory_limit) + + assert isinstance(memory_limit, int) + if memory_limit == 0: + return None + return min(memory_limit, system.MEMORY_LIMIT) + + +def _parse_threshold( + config_key: str, + deprecated_param_name: str, + deprecated_param_value: float | Literal[False] | None, +) -> float | Literal[False]: + if deprecated_param_value is not None: + warnings.warn( + f"Parameter {deprecated_param_name} has been deprecated and will be " + f"removed in a future version; please use dask config key {config_key} " + "instead", + FutureWarning, + ) + return deprecated_param_value + return dask.config.get(config_key) + + +def _warn_deprecated(w: Nanny | Worker, name: str) -> None: + warnings.warn( + f"The `{type(w).__name__}.{name}` attribute has been moved to " + f"`{type(w).__name__}.memory_manager.{name}", + FutureWarning, + ) + + +class DeprecatedMemoryManagerAttribute: + name: str + + def __set_name__(self, owner: type, name: str) -> None: + self.name = name + + def __get__(self, instance: Nanny | Worker | None, _): + if instance is None: + # This is triggered by Sphinx + return None # pragma: nocover + _warn_deprecated(instance, self.name) + return getattr(instance.memory_manager, self.name) + + def __set__(self, instance: Nanny | Worker, value) -> None: + _warn_deprecated(instance, self.name) + setattr(instance.memory_manager, self.name, value) + + +class DeprecatedMemoryMonitor: + def __get__(self, instance: Nanny | Worker | None, owner): + if instance is None: + # This is triggered by Sphinx + return None # pragma: nocover + _warn_deprecated(instance, "memory_monitor") + return partial(instance.memory_manager.memory_monitor, instance) diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 3bb87cff0d0..32b202f1d07 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,57 @@ Changelog ========= +.. 
_v2022.03.0: + +2022.03.0 +--------- + +Released on March 18, 2022 + +New Features +^^^^^^^^^^^^ +- Support dumping cluster state to URL (:pr:`5863`) `Gabe Joseph`_ + +Enhancements +^^^^^^^^^^^^ +- Prevent data duplication on unspill (:pr:`5936`) `crusaderky`_ +- Encapsulate spill buffer and memory_monitor (:pr:`5904`) `crusaderky`_ +- Drop ``pkg_resources`` in favour of ``importlib.metadata`` (:pr:`5923`) `Thomas Grainger`_ +- Worker State Machine refactor: redesign ``TaskState`` and scheduler messages (:pr:`5922`) `crusaderky`_ +- Tidying of OpenSSL 1.0.2/Python 3.9 (and earlier) handling (:pr:`5854`) `jakirkham`_ +- ``zict`` type annotations (:pr:`5905`) `crusaderky`_ +- Add key to compute failed message (:pr:`5928`) `Florian Jetter`_ +- Change default log format to include timestamp (:pr:`5897`) `Florian Jetter`_ +- Improve type annotations in worker.py (:pr:`5814`) `crusaderky`_ + +Bug Fixes +^^^^^^^^^ +- Fix ``progress_stream`` teardown (:pr:`5823`) `Thomas Grainger`_ +- Handle concurrent or failing handshakes in ``InProcListener`` (:pr:`5903`) `Thomas Grainger`_ +- Make ``log_event`` threadsafe (:pr:`5946`) `Gabe Joseph`_ + +Documentation +^^^^^^^^^^^^^ +- Fixes to documentation regarding plugins (:pr:`5940`) `crendoncoiled`_ +- Some updates to scheduling policies docs (:pr:`5911`) `Gabe Joseph`_ + +Maintenance +^^^^^^^^^^^ +- Fix ``test_nanny_worker_port_range`` hangs on Windows (:pr:`5956`) `crusaderky`_ +- (REVERTED) Unblock event loop while waiting for ThreadpoolExecutor to shut down (:pr:`5883`) `Florian Jetter`_ +- Revert :pr:`5883` (:pr:`5961`) `crusaderky`_ +- Invert ``event_name`` check in ``test-report`` job (:pr:`5959`) `jakirkham`_ +- Only run ``gh-pages`` workflow on ``dask/distributed`` (:pr:`5942`) `jakirkham`_ +- ``absolufy-imports`` - No relative imports - PEP8 (:pr:`5924`) `Florian Jetter`_ +- Fix ``track_features`` for distributed pre-releases (:pr:`5927`) `Charles Blackmon-Luca`_ +- Xfail ``test_submit_different_names`` (:pr:`5916`) `Florian Jetter`_ +- Fix ``distributed`` pre-release's ``distributed-impl`` constraint (:pr:`5867`) `Charles Blackmon-Luca`_ +- Mock process memory readings in test_worker.py (v2) (:pr:`5878`) `crusaderky`_ +- Drop unused ``_round_robin`` global variable (:pr:`5881`) `jakirkham`_ +- Add GitHub URL for PyPi (:pr:`5886`) `Andrii Oriekhov`_ +- Mark ``xfail`` COMPILED tests ``skipif`` instead (:pr:`5884`) `Florian Jetter`_ + + .. _v2022.02.1: 2022.02.1 @@ -3333,3 +3384,5 @@ significantly without many new features. .. _`Sarah Charlotte Johnson`: https://github.com/scharlottej13 .. _`Tim Harris`: https://github.com/tharris72 .. _`Bryan W. Weber`: https://github.com/bryanwweber +.. _`crendoncoiled`: https://github.com/crendoncoiled +.. _`Andrii Oriekhov`: https://github.com/andriyor \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 6535ac51034..d19ce0cfca7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -100,6 +100,7 @@ Contents scheduling-policies scheduling-state worker + worker-memory work-stealing killed diff --git a/docs/source/killed.rst b/docs/source/killed.rst index b63e9517d22..104a5e38068 100644 --- a/docs/source/killed.rst +++ b/docs/source/killed.rst @@ -102,7 +102,7 @@ interrupting any work in progress. The log will show a message like Worker exceeded X memory budget. Restarting Where X is the memory fraction. You can set this critical fraction using -the configuration, see :ref:`memman`. If you have an external system for +the configuration, see :doc:`worker-memory`. 
If you have an external system for watching memory usage provided by your cluster infrastructure (HPC, kubernetes, etc.), then it may be reasonable to turn off this memory limit. Indeed, in these cases, restarts might be handled for you too, so diff --git a/docs/source/memory.rst b/docs/source/memory.rst index 6b8e76d7d71..b207812af4e 100644 --- a/docs/source/memory.rst +++ b/docs/source/memory.rst @@ -210,4 +210,4 @@ usually necessary. Worker memory management ------------------------ -Memory usage can be optimized by configuring worker-side :ref:`memman`. +Memory usage can be optimized by configuring worker-side :doc:`worker-memory`. diff --git a/docs/source/scheduling-policies.rst b/docs/source/scheduling-policies.rst index 63fa9fa41f4..4375e4b38fe 100644 --- a/docs/source/scheduling-policies.rst +++ b/docs/source/scheduling-policies.rst @@ -50,7 +50,7 @@ would all belong to the TaskGroup ``random-a1b2c3``.) To identify the root(ish) tasks, we use this heuristic: 1. The TaskGroup has 2x more tasks than there are threads in the cluster -2. The TaskGroup has fewer than 5 dependencies across *all* tasks in the group. +2. The TaskGroup has fewer than 5 unique dependencies across *all* tasks in the group. We don't just say "The task has no dependencies", because real-world cases like :obj:`dask.array.from_zarr` and :obj:`dask.array.from_array` produce graphs like the one @@ -82,18 +82,36 @@ where can the task run the soonest, considering both data transfer and worker bu Tasks that don't meet the root-ish criteria described above are selected as follows: +First, we identify the pool of viable workers: + 1. If the task has no dependencies and no restrictions, then we find the least-occupied worker. 2. Otherwise, if a task has user-provided restrictions (for example it must run on a machine with a GPU) then we restrict the available pool of workers - to just that set, otherwise we consider all workers. -3. From among this pool of workers, we determine the workers to whom the least - amount of data would need to be transferred. -4. We break ties by choosing the worker that currently has the fewest tasks, - counting both those tasks in memory and those tasks processing currently. + to just that set. Otherwise, we consider all workers. +3. We restrict the above set to just workers that hold at least one dependency + of the task. + +From among this pool of workers, we then determine the worker where we think the task will +start running the soonest, using :meth:`Scheduler.worker_objective`. For each worker: + +1. We consider the estimated runtime of other tasks already queued on that worker. + Then, we add how long it will take to transfer any dependencies to that worker that + it doesn't already have, based on their size, in bytes, and the measured network + bandwith between workers. Note that this does *not* consider (de)serialization + time, time to retrieve the data from disk if it was spilled, or potential differences + between size in memory and serialized size. In practice, the + queue-wait-time (known as *occupancy*) usually dominates, so data will usually be + transferred to a different worker if it means the task can start any sooner. +2. It's possible for ties to occur with the "start soonest" metric, though uncommon + when all workers are busy. We break ties by choosing the worker that has the + fewest number of bytes of Dask data stored (including spilled data). 
Note that + this is the same as :ref:`managed <memtypes>` plus :ref:`spilled <memtypes>` + memory, not the :ref:`process <memtypes>` memory. This process is easy to change (and indeed this document may be outdated). We -encourage readers to inspect the ``decide_worker`` functions in ``scheduler.py``. +encourage readers to inspect the ``decide_worker`` and ``worker_objective`` +functions in ``scheduler.py``. .. currentmodule:: distributed.scheduler @@ -101,6 +119,8 @@ encourage readers to inspect the ``decide_worker`` functions in ``scheduler.py`` .. autosummary:: Scheduler.decide_worker +.. autosummary:: Scheduler.worker_objective + Choosing Tasks -------------- @@ -125,11 +145,11 @@ all of these (they all come up in important workloads) quickly. Last in, first out ~~~~~~~~~~~~~~~~~~ -When a worker finishes a task the immediate dependencies of that task get top +When a worker finishes a task, the immediate dependencies of that task get top priority. This encourages a behavior of finishing ongoing work immediately -before starting new work. This often conflicts with the -first-come-first-served objective but often results in shorter total runtimes -and significantly reduced memory footprints. +before starting new work (depth-first graph traversal). This often conflicts with +the first-come-first-served objective, but often results in significantly reduced +memory footprints and, due to avoiding data spillage to disk, better overall runtimes. .. _priority-break-ties: @@ -162,10 +182,11 @@ However, workers inevitably run out of tasks that were related to tasks they were just working on and the last-in-first-out policy eventually exhausts itself. In these cases workers often pull tasks from the common task pool. The tasks in this pool *are* ordered in a first-come-first-served basis and so -workers do behave in a fair scheduling manner at a *coarse* level if not a fine -grained one. +workers do behave in a scheduling manner that's fair to multiple submissions +at a *coarse* level, if not a fine-grained one. -Dask's scheduling policies are short-term-efficient and long-term-fair. +Dask's scheduling policies are short-term-efficient and long-term-fair +to multiple clients. Where these decisions are made @@ -187,9 +208,10 @@ scheduler, and workers at various points in the computation. policy between computations. All tasks from a previous call to compute have a higher priority than all tasks from a subsequent call to compute (or submit, persist, map, or any operation that generates futures). -3. Whenever a task is ready to run the scheduler assigns it to a worker. The - scheduler does not wait based on priority. -4. However when the worker receives these tasks it considers their priorities - when determining which tasks to prioritize for communication or for +3. Whenever a task is ready to run (its dependencies, if any, are complete), + the scheduler assigns it to a worker. When multiple tasks are ready at once, + they are all submitted to workers, in priority order. +4. However, when the worker receives these tasks, it considers their priorities + when determining which tasks to prioritize for fetching data or for computation. The worker maintains a heap of all ready-to-run tasks ordered by this priority. 
diff --git a/docs/source/worker-memory.rst b/docs/source/worker-memory.rst
new file mode 100644
index 00000000000..bad40e3ed1d
--- /dev/null
+++ b/docs/source/worker-memory.rst
@@ -0,0 +1,298 @@
+Worker Memory Management
+========================
+For cluster-wide memory management, see :doc:`memory`.
+
+Workers are given a target memory limit to stay under with the
+command-line ``--memory-limit`` keyword or the ``memory_limit=`` Python
+keyword argument, which sets the memory limit per worker process launched
+by dask-worker ::
+
+    $ dask-worker tcp://scheduler:port --memory-limit=auto  # TOTAL_MEMORY * min(1, nthreads / total_nthreads)
+    $ dask-worker tcp://scheduler:port --memory-limit="4 GiB"  # four gigabytes per worker process.
+
+Workers use a few different heuristics to keep memory use beneath this limit:
+
+Spilling based on managed memory
+--------------------------------
+Every time the worker finishes a task, it estimates the size in bytes that the result
+costs to keep in memory using the ``sizeof`` function. This function defaults to
+:func:`sys.getsizeof` for arbitrary objects, which uses the standard Python
+``__sizeof__`` protocol, but also has special-cased implementations for common data
+types like NumPy arrays and Pandas dataframes. The sum of the ``sizeof`` of all data
+tracked by Dask is called :ref:`managed memory <memtypes>`.
+
+When the managed memory exceeds 60% of the memory limit (*target threshold*), the worker
+will begin to dump the least recently used data to disk. You can control this location
+with the ``--local-directory`` keyword::
+
+    $ dask-worker tcp://scheduler:port --memory-limit="4 GiB" --local-directory /scratch
+
+That data is still available and will be read back from disk when necessary. On the
+diagnostic dashboard status page, disk I/O will show up in the task stream plot as
+orange blocks. Additionally, the memory plot in the upper left will show a section of
+the bar colored in grey.
+
+Spilling based on process memory
+--------------------------------
+The approach above can fail for a few reasons:
+
+1. Custom objects may not report their memory size accurately
+2. User functions may take up more RAM than expected
+3. Significant amounts of data may accumulate in network I/O buffers
+
+To address this, we monitor the :ref:`process memory <memtypes>` of the worker every
+200 ms. If the system-reported memory use is above 70% of the target memory usage
+(*spill threshold*), then the worker will start dumping unused data to disk, even if
+internal ``sizeof`` recording hasn't yet reached the normal 60% threshold. This more
+aggressive spilling will continue until process memory falls below 60%.
+
+Pause worker
+------------
+At 80% :ref:`process memory <memtypes>` load, the worker's thread pool will stop
+starting computation on additional tasks in the worker's queue. This gives time for the
+write-to-disk functionality to take effect even in the face of rapidly accumulating
+data. Currently executing tasks continue to run. Additionally, data transfers to/from
+other workers are throttled to a bare minimum.
+
+Kill Worker
+-----------
+At 95% :ref:`process memory <memtypes>` load (*terminate threshold*), a worker's nanny
+process will terminate it. Tasks will be cancelled mid-execution and rescheduled
+elsewhere; all unique data on the worker will be lost and will need to be recomputed.
+This is to avoid having our worker job terminated by an external watchdog (like
+Kubernetes, YARN, Mesos, SGE, etc.).
After termination, the nanny will restart the +worker in a fresh state. + +Thresholds configuration +------------------------ +These values can be configured by modifying the ``~/.config/dask/distributed.yaml`` +file: + +.. code-block:: yaml + + distributed: + worker: + # Fractions of worker process memory at which we take action to avoid memory + # blowup. Set any of the values to False to turn off the behavior entirely. + memory: + target: 0.60 # fraction of managed memory where we start spilling to disk + spill: 0.70 # fraction of process memory where we start spilling to disk + pause: 0.80 # fraction of process memory at which we pause worker threads + terminate: 0.95 # fraction of process memory at which we terminate the worker + +Using the dashboard to monitor memory usage +------------------------------------------- +The dashboard (typically available on port 8787) shows a summary of the overall memory +usage on the cluster, as well as the individual usage on each worker. It provides +different memory readings: + +.. _memtypes: + +process + Overall memory used by the worker process (RSS), as measured by the OS + +managed + Sum of the ``sizeof`` of all Dask data stored on the worker, excluding + spilled data. + +unmanaged + Memory usage that Dask is not directly aware of. It is estimated by subtracting + managed memory from the total process memory and typically includes: + + - The Python interpreter code, loaded modules, and global variables + - Memory temporarily used by running tasks + - Dereferenced Python objects that have not been garbage-collected yet + - Unused memory that the Python memory allocator did not return to libc through + `free`_ yet + - Unused memory that the user-space libc `free`_ function did not release to the OS + yet (see memory allocators below) + - Memory fragmentation + - Memory leaks + +unmanaged recent + Unmanaged memory that has appeared within the last 30 seconds. This is not included + in the 'unmanaged' memory measure above. Ideally, this memory should be for the most + part a temporary spike caused by tasks' heap use plus soon-to-be garbage collected + objects. + + The time it takes for unmanaged memory to transition away from its "recent" state + can be tweaked through the ``distributed.worker.memory.recent-to-old-time`` key in + the ``~/.config/dask/distributed.yaml`` file. If your tasks typically run for longer + than 30 seconds, it's recommended that you increase this setting accordingly. + + By default, :meth:`distributed.Client.rebalance` and + :meth:`distributed.scheduler.Scheduler.rebalance` ignore unmanaged recent memory. + This behaviour can also be tweaked using the Dask config - see the methods' + documentation. + +spilled + managed memory that has been spilled to disk. This is not included in the 'managed' + measure above. This measure reports the number of bytes actually spilled to disk, + which may differ from the output of ``sizeof`` particularly in case of compression. + +The sum of managed + unmanaged + unmanaged recent is equal by definition to the process +memory. + + +.. _memtrim: + +Memory not released back to the OS +---------------------------------- +In many cases, high unmanaged memory usage or "memory leak" warnings on workers can be +misleading: a worker may not actually be using its memory for anything, but simply +hasn't returned that unused memory back to the operating system, and is hoarding it just +in case it needs the memory capacity again. 
This is not a bug in your code, nor in +Dask — it's actually normal behavior for all processes on Linux and MacOS, and is a +consequence of how the low-level memory allocator works (see below for details). + +Because Dask makes decisions (spill-to-disk, pause, terminate, +:meth:`~distributed.Client.rebalance`) based on the worker's memory usage as reported by +the OS, and is unaware of how much of this memory is actually in use versus empty and +"hoarded", it can overestimate — sometimes significantly — how much memory the process +is using and think the worker is running out of memory when in fact it isn't. + +More in detail: both the Linux and MacOS memory allocators try to avoid performing a +`brk`_ kernel call every time the application calls `free`_ by implementing a user-space +memory management system. Upon `free`_, memory can remain allocated in user space and +potentially reusable by the next `malloc`_ - which in turn won't require a kernel call +either. This is generally very desirable for C/C++ applications which have no memory +allocator of their own, as it can drastically boost performance at the cost of a larger +memory footprint. CPython however adds its own memory allocator on top, which reduces +the need for this additional abstraction (with caveats). + +There are steps you can take to alleviate situations where worker memory is not released +back to the OS. These steps are discussed in the following sections. + +Manually trim memory +~~~~~~~~~~~~~~~~~~~~ +*Linux workers only* + +It is possible to forcefully release allocated but unutilized memory as follows: + +.. code-block:: python + + import ctypes + + def trim_memory() -> int: + libc = ctypes.CDLL("libc.so.6") + return libc.malloc_trim(0) + + client.run(trim_memory) + +This should be only used as a one-off debugging experiment. Watch the dashboard while +running the above code. If unmanaged worker memory (on the "Bytes stored" plot) +decreases significantly after calling ``client.run(trim_memory)``, then move on to the +next section. Otherwise, you likely do have a memory leak. + +Note that you should only run this `malloc_trim`_ if you are using the default glibc +memory allocator. When using a custom allocator such as `jemalloc`_ (see below), this +could cause unexpected behavior including segfaults. (If you don't know what this means, +you're probably using the default glibc allocator and are safe to run this). + +Automatically trim memory +~~~~~~~~~~~~~~~~~~~~~~~~~ +*Linux workers only* + +To aggressively and automatically trim the memory in a production environment, you +should instead set the environment variable ``MALLOC_TRIM_THRESHOLD_`` (note the final +underscore) to 0 or a low number; see the `mallopt`_ man page for details. Reducing +this value will increase the number of syscalls, and as a consequence may degrade +performance. + +.. note:: + The variable must be set before starting the ``dask-worker`` process. + +.. note:: + If using a :ref:`nanny`, the ``MALLOC_TRIM_THRESHOLD_`` environment variable + will automatically be set to ``65536`` for the worker process which the nanny is + monitoring. You can modify this behavior using the ``distributed.nanny.environ`` + configuration value. + +jemalloc +~~~~~~~~ +*Linux and MacOS workers* + +Alternatively to the above, you may experiment with the `jemalloc`_ memory allocator, as +follows: + +On Linux: + +.. code-block:: bash + + conda install jemalloc + LD_PRELOAD=$CONDA_PREFIX/lib/libjemalloc.so dask-worker <...> + +On macOS: + +.. 
code-block:: bash + + conda install jemalloc + DYLD_INSERT_LIBRARIES=$CONDA_PREFIX/lib/libjemalloc.dylib dask-worker <...> + +Alternatively on macOS, install globally with `homebrew`_: + +.. code-block:: bash + + brew install jemalloc + DYLD_INSERT_LIBRARIES=$(brew --prefix jemalloc)/lib/libjemalloc.dylib dask-worker <...> + +`jemalloc`_ offers a wealth of configuration settings; please refer to its +documentation. + +Ignore process memory +~~~~~~~~~~~~~~~~~~~~~ +If all else fails, you may want to stop Dask from using memory metrics from the OS (RSS) +in its decision-making: + +.. code-block:: yaml + + distributed: + worker: + memory: + rebalance: + measure: managed_in_memory + spill: false + pause: false + terminate: false + +This of course will be problematic if you have a genuine issue with unmanaged memory, +e.g. memory leaks and/or suffer from heavy fragmentation. + + +User-defined managed memory containers +-------------------------------------- +.. warning:: + This feature is intended for advanced users only; the built-in container for managed + memory should fit the needs of most. If you're looking to dynamically spill CUDA + device memory into host memory, you should use `dask-cuda`_. + +The design described in the sections above stores data in the worker's RAM, with +automatic spilling to disk when the ``target`` or ``spill`` thresholds are passed. +If one desires a different behaviour, a ``data=`` parameter can be passed when +initializing the :class:`~distributed.worker.Worker` or +:class:`~distributed.nanny.Nanny`. +This optional parameter accepts any of the following values: + +- an instance of ``MutableMapping[str, Any]`` +- a callable which returns a ``MutableMapping[str, Any]`` +- a tuple of + + - callable which returns a ``MutableMapping[str, Any]`` + - dict of keyword arguments to the callable + +Doing so causes the Worker to ignore both the ``target`` and the ``spill`` thresholds. +However, if the object also supports the following duck-type API in addition to the +MutableMapping API, the ``spill`` threshold will remain active: + +.. autoclass:: distributed.spill.ManualEvictProto + :members: + + +.. _malloc: https://www.man7.org/linux/man-pages/man3/malloc.3.html +.. _free: https://www.man7.org/linux/man-pages/man3/free.3.html +.. _mallopt: https://man7.org/linux/man-pages/man3/mallopt.3.html +.. _malloc_trim: https://man7.org/linux/man-pages/man3/malloc_trim.3.html +.. _brk: https://www.man7.org/linux/man-pages/man2/brk.2.html +.. _jemalloc: http://jemalloc.net +.. _homebrew: https://brew.sh/ +.. _dask-cuda: https://docs.rapids.ai/api/dask-cuda/stable/index.html diff --git a/docs/source/worker.rst b/docs/source/worker.rst index e6892220eb3..91cac09947a 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -48,6 +48,8 @@ This ``.data`` attribute is a ``MutableMapping`` that is typically a combination of in-memory and on-disk storage with an LRU policy to move data between them. +Read more: :doc:`worker-memory` + Thread Pool ----------- @@ -145,267 +147,6 @@ exceptions to this are when: occurs when a `worker dies <killed>`_ during computation. -.. 
_memman: - -Memory Management ------------------ -Workers are given a target memory limit to stay under with the -command line ``--memory-limit`` keyword or the ``memory_limit=`` Python -keyword argument, which sets the memory limit per worker processes launched -by dask-worker :: - - $ dask-worker tcp://scheduler:port --memory-limit=auto # TOTAL_MEMORY * min(1, nthreads / total_nthreads) - $ dask-worker tcp://scheduler:port --memory-limit="4 GiB" # four gigabytes per worker process. - -Workers use a few different heuristics to keep memory use beneath this limit: - -Spilling based on managed memory -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -Every time the worker finishes a task, it estimates the size in bytes that the result -costs to keep in memory using the ``sizeof`` function. This function defaults to -:func:`sys.getsizeof` for arbitrary objects, which uses the standard Python -``__sizeof__`` protocol, but also has special-cased implementations for common data -types like NumPy arrays and Pandas dataframes. The sum of the ``sizeof`` of all data -tracked by Dask is called :ref:`managed memory <memtypes>`. - -When the managed memory exceeds 60% of the memory limit (*target threshold*), the worker -will begin to dump the least recently used data to disk. You can control this location -with the ``--local-directory`` keyword:: - - $ dask-worker tcp://scheduler:port --memory-limit="4 GiB" --local-directory /scratch - -That data is still available and will be read back from disk when necessary. On the -diagnostic dashboard status page, disk I/O will show up in the task stream plot as -orange blocks. Additionally, the memory plot in the upper left will show a section of -the bar colored in grey. - -Spilling based on process memory -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The approach above can fail for a few reasons: - -1. Custom objects may not report their memory size accurately -2. User functions may take up more RAM than expected -3. Significant amounts of data may accumulate in network I/O buffers - -To address this, we periodically monitor the :ref:`process memory <memtypes>` of the -worker every 200 ms. If the system reported memory use is above 70% of the target memory -usage (*spill threshold*), then the worker will start dumping unused data to disk, even -if internal ``sizeof`` recording hasn't yet reached the normal 60% threshold. This -more aggressive spilling will continue until process memory falls below 60%. - -Pause worker -~~~~~~~~~~~~ -At 80% :ref:`process memory <memtypes>` load, the worker's thread pool will stop -starting computation on additional tasks in the worker's queue. This gives time for the -write-to-disk functionality to take effect even in the face of rapidly accumulating -data. Currently executing tasks continue to run. Additionally, data transfers to/from -other workers are throttled to a bare minimum. - -Kill Worker -~~~~~~~~~~~ -At 95% :ref:`process memory <memtypes>` load (*terminate threshold*), a worker's nanny -process will terminate it. Tasks will be cancelled mid-execution and rescheduled -elsewhere; all unique data on the worker will be lost and will need to be recomputed. -This is to avoid having our worker job being terminated by an external watchdog (like -Kubernetes, YARN, Mesos, SGE, etc..). After termination, the nanny will restart the -worker in a fresh state. - -Thresholds configuration -~~~~~~~~~~~~~~~~~~~~~~~~ -These values can be configured by modifying the ``~/.config/dask/distributed.yaml`` -file: - -.. 
code-block:: yaml - - distributed: - worker: - # Fractions of worker process memory at which we take action to avoid memory - # blowup. Set any of the values to False to turn off the behavior entirely. - memory: - target: 0.60 # fraction of managed memory where we start spilling to disk - spill: 0.70 # fraction of process memory where we start spilling to disk - pause: 0.80 # fraction of process memory at which we pause worker threads - terminate: 0.95 # fraction of process memory at which we terminate the worker - -Using the dashboard to monitor memory usage -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -The dashboard (typically available on port 8787) shows a summary of the overall memory -usage on the cluster, as well as the individual usage on each worker. It provides -different memory readings: - -.. _memtypes: - -process - Overall memory used by the worker process (RSS), as measured by the OS - -managed - Sum of the ``sizeof`` of all Dask data stored on the worker, excluding - spilled data. - -unmanaged - Memory usage that Dask is not directly aware of. It is estimated by subtracting - managed memory from the total process memory and typically includes: - - - The Python interpreter code, loaded modules, and global variables - - Memory temporarily used by running tasks - - Dereferenced Python objects that have not been garbage-collected yet - - Unused memory that the Python memory allocator did not return to libc through - `free`_ yet - - Unused memory that the user-space libc `free`_ function did not release to the OS - yet (see memory allocators below) - - Memory fragmentation - - Memory leaks - -unmanaged recent - Unmanaged memory that has appeared within the last 30 seconds. This is not included - in the 'unmanaged' memory measure above. Ideally, this memory should be for the most - part a temporary spike caused by tasks' heap use plus soon-to-be garbage collected - objects. - - The time it takes for unmanaged memory to transition away from its "recent" state - can be tweaked through the ``distributed.worker.memory.recent-to-old-time`` key in - the ``~/.config/dask/distributed.yaml`` file. If your tasks typically run for longer - than 30 seconds, it's recommended that you increase this setting accordingly. - - By default, :meth:`distributed.Client.rebalance` and - :meth:`distributed.scheduler.Scheduler.rebalance` ignore unmanaged recent memory. - This behaviour can also be tweaked using the Dask config - see the methods' - documentation. - -spilled - managed memory that has been spilled to disk. This is not included in the 'managed' - measure above. This measure reports the number of bytes actually spilled to disk, - which may differ from the output of ``sizeof`` particularly in case of compression. - -The sum of managed + unmanaged + unmanaged recent is equal by definition to the process -memory. - - -.. _memtrim: - -Memory not released back to the OS ----------------------------------- -In many cases, high unmanaged memory usage or "memory leak" warnings on workers can be -misleading: a worker may not actually be using its memory for anything, but simply -hasn't returned that unused memory back to the operating system, and is hoarding it just -in case it needs the memory capacity again. This is not a bug in your code, nor in -Dask — it's actually normal behavior for all processes on Linux and MacOS, and is a -consequence of how the low-level memory allocator works (see below for details). 
- -Because Dask makes decisions (spill-to-disk, pause, terminate, -:meth:`~distributed.Client.rebalance`) based on the worker's memory usage as reported by -the OS, and is unaware of how much of this memory is actually in use versus empty and -"hoarded", it can overestimate — sometimes significantly — how much memory the process -is using and think the worker is running out of memory when in fact it isn't. - -More in detail: both the Linux and MacOS memory allocators try to avoid performing a -`brk`_ kernel call every time the application calls `free`_ by implementing a user-space -memory management system. Upon `free`_, memory can remain allocated in user space and -potentially reusable by the next `malloc`_ - which in turn won't require a kernel call -either. This is generally very desirable for C/C++ applications which have no memory -allocator of their own, as it can drastically boost performance at the cost of a larger -memory footprint. CPython however adds its own memory allocator on top, which reduces -the need for this additional abstraction (with caveats). - -There are steps you can take to alleviate situations where worker memory is not released -back to the OS. These steps are discussed in the following sections. - -Manually trim memory -~~~~~~~~~~~~~~~~~~~~ -*Linux workers only* - -It is possible to forcefully release allocated but unutilized memory as follows: - -.. code-block:: python - - import ctypes - - def trim_memory() -> int: - libc = ctypes.CDLL("libc.so.6") - return libc.malloc_trim(0) - - client.run(trim_memory) - -This should be only used as a one-off debugging experiment. Watch the dashboard while -running the above code. If unmanaged worker memory (on the "Bytes stored" plot) -decreases significantly after calling ``client.run(trim_memory)``, then move on to the -next section. Otherwise, you likely do have a memory leak. - -Note that you should only run this `malloc_trim`_ if you are using the default glibc -memory allocator. When using a custom allocator such as `jemalloc`_ (see below), this -could cause unexpected behavior including segfaults. (If you don't know what this means, -you're probably using the default glibc allocator and are safe to run this). - -Automatically trim memory -~~~~~~~~~~~~~~~~~~~~~~~~~ -*Linux workers only* - -To aggressively and automatically trim the memory in a production environment, you -should instead set the environment variable ``MALLOC_TRIM_THRESHOLD_`` (note the final -underscore) to 0 or a low number; see the `mallopt`_ man page for details. Reducing -this value will increase the number of syscalls, and as a consequence may degrade -performance. - -.. note:: - The variable must be set before starting the ``dask-worker`` process. - -.. note:: - If using a :ref:`nanny`, the ``MALLOC_TRIM_THRESHOLD_`` environment variable - will automatically be set to ``65536`` for the worker process which the nanny is - monitoring. You can modify this behavior using the ``distributed.nanny.environ`` - configuration value. - -jemalloc -~~~~~~~~ -*Linux and MacOS workers* - -Alternatively to the above, you may experiment with the `jemalloc`_ memory allocator, as -follows: - -On Linux: - -.. code-block:: bash - - conda install jemalloc - LD_PRELOAD=$CONDA_PREFIX/lib/libjemalloc.so dask-worker <...> - -On macOS: - -.. code-block:: bash - - conda install jemalloc - DYLD_INSERT_LIBRARIES=$CONDA_PREFIX/lib/libjemalloc.dylib dask-worker <...> - -Alternatively on macOS, install globally with `homebrew`_: - -.. 
code-block:: bash - - brew install jemalloc - DYLD_INSERT_LIBRARIES=$(brew --prefix jemalloc)/lib/libjemalloc.dylib dask-worker <...> - -`jemalloc`_ offers a wealth of configuration settings; please refer to its -documentation. - -Ignore process memory -~~~~~~~~~~~~~~~~~~~~~ -If all else fails, you may want to stop Dask from using memory metrics from the OS (RSS) -in its decision-making: - -.. code-block:: yaml - - distributed: - worker: - memory: - rebalance: - measure: managed_in_memory - spill: false - pause: false - terminate: false - -This of course will be problematic if you have a genuine issue with unmanaged memory, -e.g. memory leaks and/or suffer from heavy fragmentation. - - .. _nanny: Nanny @@ -426,12 +167,3 @@ API Documentation .. autoclass:: distributed.worker.Worker :members: - - -.. _malloc: https://www.man7.org/linux/man-pages/man3/malloc.3.html -.. _free: https://www.man7.org/linux/man-pages/man3/free.3.html -.. _mallopt: https://man7.org/linux/man-pages/man3/mallopt.3.html -.. _malloc_trim: https://man7.org/linux/man-pages/man3/malloc_trim.3.html -.. _brk: https://www.man7.org/linux/man-pages/man2/brk.2.html -.. _jemalloc: http://jemalloc.net -.. _homebrew: https://brew.sh/ diff --git a/requirements.txt b/requirements.txt index 558b0570971..b6b7ed5d824 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ click >= 6.6 cloudpickle >= 1.5.0 -dask == 2022.02.1 +dask == 2022.03.0 jinja2 msgpack >= 0.6.0 packaging >= 20.0 @@ -11,4 +11,3 @@ toolz >= 0.8.2 tornado >= 6.0.3 zict >= 0.1.3 pyyaml -setuptools diff --git a/setup.cfg b/setup.cfg index aebd0a81dee..dd99eccfc7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,7 @@ skip_gitignore = true force_to_top = true default_section = THIRDPARTY known_first_party = distributed -known_distributed = dask +known_distributed = dask,zict [versioneer] VCS = git