diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index 98223e87b2..c232e2422b 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -1,14 +1,14 @@
 from __future__ import annotations
 
 import asyncio
-import subprocess
-import sys
+import gc
 from collections.abc import Iterator
 
 import pytest
 from tlz import first
 
-from distributed import Worker, wait
+import distributed.profile as profile
+from distributed import Nanny, Worker, wait
 from distributed.protocol.serialize import Serialize
 from distributed.utils import recursive_to_dict
 from distributed.utils_test import (
@@ -48,11 +48,12 @@ async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -
 
 @clean()
 def test_task_state_tracking():
-    x = TaskState("x")
-    assert len(TaskState._instances) == 1
-    assert first(TaskState._instances) == x
+    with clean():
+        x = TaskState("x")
+        assert len(TaskState._instances) == 1
+        assert first(TaskState._instances) == x
 
-    del x
+        del x
     assert len(TaskState._instances) == 0
 
 
@@ -689,46 +690,23 @@ async def test_missing_to_waiting(c, s, w1, w2, w3):
     await f1
 
 
-client_script = """
-from dask.distributed import Client
-from dask.distributed.worker_state_machine import TaskState
-
-
-def inc(x):
-    return x + 1
-
-
-if __name__ == "__main__":
-    with Client(processes=%s, n_workers=1) as client:
-        futs = client.map(inc, range(10))
-        red = client.submit(sum, futs)
-        f1 = client.submit(inc, red, pure=False)
-        f2 = client.submit(inc, red, pure=False)
-        f2.result()
-        del futs, red, f1, f2
-
-        def check():
-            assert not TaskState._instances, len(TaskState._instances)
+@gen_cluster(client=True, Worker=Nanny)
+async def test_task_state_instance_are_garbage_collected(c, s, a, b):
+    futs = c.map(inc, range(10))
+    red = c.submit(sum, futs)
+    f1 = c.submit(inc, red, pure=False)
+    f2 = c.submit(inc, red, pure=False)
 
-        client.run(check)
-"""
-
-
-@pytest.mark.parametrize("processes", [True, False])
-def test_task_state_instance_are_garbage_collected(processes, tmp_path):
-    with open(tmp_path / "script.py", mode="w") as f:
-        f.write(client_script % processes)
-
-    proc = subprocess.Popen(
-        [sys.executable, tmp_path / "script.py"],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    )
-
-    out, err = proc.communicate()
+    async def check(dask_worker):
+        while dask_worker.tasks:
+            await asyncio.sleep(0.01)
+        with profile.lock:
+            gc.collect()
+        assert not TaskState._instances
 
-    assert not out
-    assert not err
+    await c.gather([f2, f1])
+    del futs, red, f1, f2
+    await c.run(check)
 
 
 @gen_cluster(client=True, nthreads=[("", 1)] * 3)
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index c76c5ef6bd..c4ee6d4c1f 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -3002,6 +3002,11 @@ def validate_state(self) -> None:
         if self.transition_counter_max:
             assert self.transition_counter < self.transition_counter_max
 
+        # Test that there aren't multiple TaskState objects with the same key in data_needed
+        assert len({ts.key for ts in self.data_needed}) == len(self.data_needed)
+        for tss in self.data_needed_per_worker.values():
+            assert len({ts.key for ts in tss}) == len(tss)
+
 
 class BaseWorker(abc.ABC):
     """Wrapper around the :class:`WorkerState` that implements instructions handling.