Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Jun 15, 2022
1 parent f8a0e9b commit ea0281d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 45 deletions.
68 changes: 23 additions & 45 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import asyncio
import subprocess
import sys
import gc
from collections.abc import Iterator

import pytest
from tlz import first

from distributed import Worker, wait
import distributed.profile as profile
from distributed import Nanny, Worker, wait
from distributed.protocol.serialize import Serialize
from distributed.utils import recursive_to_dict
from distributed.utils_test import (
Expand Down Expand Up @@ -48,11 +48,12 @@ async def wait_for_state(key: str, state: TaskStateState, dask_worker: Worker) -

@clean()
def test_task_state_tracking():
x = TaskState("x")
assert len(TaskState._instances) == 1
assert first(TaskState._instances) == x
with clean():
x = TaskState("x")
assert len(TaskState._instances) == 1
assert first(TaskState._instances) == x

del x
del x
assert len(TaskState._instances) == 0


Expand Down Expand Up @@ -689,46 +690,23 @@ async def test_missing_to_waiting(c, s, w1, w2, w3):
await f1


client_script = """
from dask.distributed import Client
from dask.distributed.worker_state_machine import TaskState
def inc(x):
return x + 1
if __name__ == "__main__":
with Client(processes=%s, n_workers=1) as client:
futs = client.map(inc, range(10))
red = client.submit(sum, futs)
f1 = client.submit(inc, red, pure=False)
f2 = client.submit(inc, red, pure=False)
f2.result()
del futs, red, f1, f2
def check():
assert not TaskState._instances, len(TaskState._instances)
@gen_cluster(client=True, Worker=Nanny)
async def test_task_state_instance_are_garbage_collected(c, s, a, b):
futs = c.map(inc, range(10))
red = c.submit(sum, futs)
f1 = c.submit(inc, red, pure=False)
f2 = c.submit(inc, red, pure=False)

client.run(check)
"""


@pytest.mark.parametrize("processes", [True, False])
def test_task_state_instance_are_garbage_collected(processes, tmp_path):
with open(tmp_path / "script.py", mode="w") as f:
f.write(client_script % processes)

proc = subprocess.Popen(
[sys.executable, tmp_path / "script.py"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

out, err = proc.communicate()
async def check(dask_worker):
while dask_worker.tasks:
await asyncio.sleep(0.01)
with profile.lock:
gc.collect()
assert not TaskState._instances

assert not out
assert not err
await c.gather([f2, f1])
del futs, red, f1, f2
await c.run(check)


@gen_cluster(client=True, nthreads=[("", 1)] * 3)
Expand Down
5 changes: 5 additions & 0 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3002,6 +3002,11 @@ def validate_state(self) -> None:
if self.transition_counter_max:
assert self.transition_counter < self.transition_counter_max

# Test that there aren't multiple TaskState objects with the same key in
# data_needed, nor within any single worker's entry of data_needed_per_worker
assert len({ts.key for ts in self.data_needed}) == len(self.data_needed)
for tss in self.data_needed_per_worker.values():
assert len({ts.key for ts in tss}) == len(tss)


class BaseWorker(abc.ABC):
"""Wrapper around the :class:`WorkerState` that implements instructions handling.
Expand Down

0 comments on commit ea0281d

Please sign in to comment.