Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Harden vs. TaskState collisions #6593

Merged
merged 2 commits into from
Jun 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def test_TaskState_get_nbytes():
assert TaskState("y").get_nbytes() == 1024


def test_TaskState_eq():
    """Two TaskState objects built from the same key must be hashable and must
    not compare equal, so a set holding both keeps two distinct members.
    See comment in TaskState.__hash__ for why.
    """
    first = TaskState("x")
    second = TaskState("x")
    # Same key, distinct objects: equality is expected to be identity-based.
    assert first != second
    # Both objects must survive de-duplication when placed in a set.
    assert len({first, second}) == 2


def test_TaskState__to_dict():
"""Tasks that are listed as dependencies or dependents of other tasks are dumped as
a short repr and always appear in full directly under Worker.state.tasks.
Expand Down
37 changes: 27 additions & 10 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from copy import copy
from dataclasses import dataclass, field
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast

from tlz import peekn, pluck
Expand Down Expand Up @@ -274,15 +275,22 @@ def __post_init__(self) -> None:
def __repr__(self) -> str:
    """Return a compact debug string with the task's key (repr'd) and its state."""
    text = f"<TaskState {self.key!r} {self.state}>"
    return text

def __eq__(self, other: object) -> bool:
    """Compare by object identity, never by key.

    A task may be forgotten and a new TaskState object with the same key may be
    created in its place later on. In the Worker state, you should never have
    multiple TaskState objects with the same key. That can't be asserted here,
    however, as this comparison is also used in WeakSets for instance tracking
    purposes.
    """
    return self is other

def __hash__(self) -> int:
    """Hash by object identity: return ``id(self)``.

    This overrides the dataclass-generated ``__hash__`` and goes together with
    identity-based equality, i.e. ``(a == b) == (a is b)``.

    On first thought, it would make sense to use TaskState.key for equality and
    hashing. However, a task may be forgotten and a new TaskState object with the
    same key may be created in its place later on. In the Worker state, you should
    never have multiple TaskState objects with the same key; see
    WorkerState.validate_state for relevant checks. We can't assert the same thing
    in __eq__ though, as multiple objects with the same key may appear in
    TaskState._instances for a brief period of time.
    """
    # Deliberately NOT hash(self.key): two live objects sharing a key must be
    # able to coexist as distinct members of the same set or dict.
    return id(self)

def get_nbytes(self) -> int:
nbytes = self.nbytes
Expand Down Expand Up @@ -3015,14 +3023,23 @@ def validate_state(self) -> None:

for ts in self.data_needed:
assert ts.state == "fetch", self.story(ts)
assert self.tasks[ts.key] is ts
for worker, tss in self.data_needed_per_worker.items():
for ts in tss:
assert ts.state == "fetch"
assert self.tasks[ts.key] is ts
assert ts in self.data_needed
assert worker in ts.who_has

# Test that there aren't multiple TaskState objects with the same key in any
# Set[TaskState]. See note in TaskState.__hash__.
for ts in chain(
self.data_needed,
*self.data_needed_per_worker.values(),
self.missing_dep_flight,
self.in_flight_tasks,
self.executing,
):
assert self.tasks[ts.key] is ts

for ts in self.tasks.values():
self.validate_task(ts)

Expand Down