diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 1252b302525..667032f6e50 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -57,6 +57,17 @@ def test_TaskState_get_nbytes(): assert TaskState("y").get_nbytes() == 1024 +def test_TaskState_eq(): + """Test that TaskState objects are hashable and that two identical objects compare + as different. See comment in TaskState.__hash__ for why. + """ + a = TaskState("x") + b = TaskState("x") + assert a != b + s = {a, b} + assert len(s) == 2 + + def test_TaskState__to_dict(): """Tasks that are listed as dependencies or dependents of other tasks are dumped as a short repr and always appear in full directly under Worker.state.tasks. diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index b740b42e28d..7dd3431ef14 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -22,6 +22,7 @@ from copy import copy from dataclasses import dataclass, field from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast from tlz import peekn, pluck @@ -274,15 +275,22 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"" - def __eq__(self, other: object) -> bool: - # A task may be forgotten and a new TaskState object with the same key may be - # created in its place later on. In the Worker state, you should never have - # multiple TaskState objects with the same key. We can't assert it here however, - # as this comparison is also used in WeakSets for instance tracking purposes. - return other is self - def __hash__(self) -> int: - return hash(self.key) + """Override dataclass __hash__, reverting to the default behaviour + hash(o) == id(o). + + Note that we also defined @dataclass(eq=False), which reverts to the default + behaviour (a == b) == (a is b). + + On first thought, it would make sense to use TaskState.key for equality and + hashing. However, a task may be forgotten and a new TaskState object with the + same key may be created in its place later on. In the Worker state, you should + never have multiple TaskState objects with the same key; see + WorkerState.validate_state for relevant checks. We can't assert the same thing + in __eq__ though, as multiple objects with the same key may appear in + TaskState._instances for a brief period of time. + """ + return id(self) def get_nbytes(self) -> int: nbytes = self.nbytes @@ -3015,14 +3023,23 @@ def validate_state(self) -> None: for ts in self.data_needed: assert ts.state == "fetch", self.story(ts) - assert self.tasks[ts.key] is ts for worker, tss in self.data_needed_per_worker.items(): for ts in tss: assert ts.state == "fetch" - assert self.tasks[ts.key] is ts assert ts in self.data_needed assert worker in ts.who_has + # Test that there aren't multiple TaskState objects with the same key in any + # Set[TaskState]. See note in TaskState.__hash__. + for ts in chain( + self.data_needed, + *self.data_needed_per_worker.values(), + self.missing_dep_flight, + self.in_flight_tasks, + self.executing, + ): + assert self.tasks[ts.key] is ts + for ts in self.tasks.values(): self.validate_task(ts)