Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC Sign every compute task with a unique counter to correlate responses #7372

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,14 @@ class TaskState:
#: Task annotations
annotations: dict[str, Any]

#: A counter that counts how often a task was already assigned to a Worker.
#: This counter is used to sign a task such that the assigned Worker is
#: expected to return the same counter in the task-finished message. This is
#: used to correlate responses.
#: Only the most recently assigned worker is trusted. All other results
#: will be rejected.
_attempt: int

#: Cached hash of :attr:`~TaskState.client_key`
_hash: int

Expand Down Expand Up @@ -1354,6 +1362,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.metadata = {}
self.annotations = {}
self.erred_on = set()
self._attempt = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this means that if we re-create a task, e.g. by re-running a workload after it has been forgotten, its attempt count would be reset to 0. This would mean that the attempt counter is not unique. We should use a global counter instead that we also actively increment. transition_counter would be one possible candidate if we make sure to increment it every time we update _attempt.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure but I'd use a different name for this then since it's no longer an attempt

TaskState._instances.add(self)

def __hash__(self) -> int:
Expand Down Expand Up @@ -3257,9 +3266,10 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
# time to compute and submit this
if duration < 0:
duration = self.get_task_duration(ts)

ts._attempt += 1
msg: dict[str, Any] = {
"op": "compute-task",
"attempt": ts._attempt,
"key": ts.key,
"priority": ts.priority,
"duration": duration,
Expand Down Expand Up @@ -4620,7 +4630,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None:

self.transitions(recommendations, stimulus_id)

def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
def stimulus_task_finished(self, key, worker, stimulus_id, attempt, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)

Expand All @@ -4630,7 +4640,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar

ws: WorkerState = self.workers[worker]
ts: TaskState = self.tasks.get(key)
if ts is None or ts.state in ("released", "queued"):
if ts is None or ts._attempt != attempt:
logger.debug(
"Received already computed task, worker: %s, state: %s"
", key: %s, who_has: %s",
Expand All @@ -4644,6 +4654,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar
"op": "free-keys",
"keys": [key],
"stimulus_id": stimulus_id,
"attempts": [attempt],
}
]
elif ts.state == "memory":
Expand Down
3 changes: 3 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def test_computetask_to_dict():
function=b"blob",
args=b"blob",
kwargs=None,
attempt=5,
)
assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob")
ev2 = ev.to_loggable(handled=11.22)
Expand All @@ -386,6 +387,7 @@ def test_computetask_to_dict():
"function": None,
"args": None,
"kwargs": None,
"attempt": 5,
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ComputeTaskEvent)
Expand All @@ -409,6 +411,7 @@ def test_computetask_dummy():
function=None,
args=None,
kwargs=None,
attempt=0,
)

# nbytes is generated from who_has if omitted
Expand Down
4 changes: 4 additions & 0 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
ExecuteFailureEvent,
ExecuteSuccessEvent,
FindMissingEvent,
FreeKeyByAttemptEvent,
FreeKeysEvent,
GatherDepBusyEvent,
GatherDepFailureEvent,
Expand Down Expand Up @@ -735,6 +736,7 @@ def __init__(
"acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent),
"compute-task": self._handle_remote_stimulus(ComputeTaskEvent),
"free-keys": self._handle_remote_stimulus(FreeKeysEvent),
"free_keys_attempt": self._handle_remote_stimulus(FreeKeyByAttemptEvent),
"remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent),
"steal-request": self._handle_remote_stimulus(StealRequestEvent),
"refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent),
Expand Down Expand Up @@ -2240,6 +2242,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:

# The key *must* be in the worker state thanks to the cancelled state
ts = self.state.tasks[key]
execution_attempt = ts.attempt

try:
function, args, kwargs = await self._maybe_deserialize_task(ts)
Expand Down Expand Up @@ -2316,6 +2319,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
stop=result["stop"],
nbytes=result["nbytes"],
type=result["type"],
attempt=execution_attempt,
stimulus_id=f"task-finished-{time()}",
)

Expand Down
77 changes: 63 additions & 14 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
from dataclasses import dataclass, field
from functools import lru_cache, partial, singledispatchmethod
from itertools import chain
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Literal,
NamedTuple,
Sequence,
TypedDict,
cast,
)

from tlz import peekn

Expand Down Expand Up @@ -285,6 +294,7 @@ class TaskState:
#: the behaviour of transitions out of the ``executing``, ``flight`` etc. states.
done: bool = False

attempt: int = 0
_instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()

# Support for weakrefs to a class with __slots__
Expand Down Expand Up @@ -459,6 +469,7 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
attempt: int
__slots__ = tuple(__annotations__)

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -749,6 +760,8 @@ class ComputeTaskEvent(StateMachineEvent):
resource_restrictions: dict[str, float]
actor: bool
annotations: dict
attempt: int

__slots__ = tuple(__annotations__)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -794,6 +807,7 @@ def dummy(
resource_restrictions: dict[str, float] | None = None,
actor: bool = False,
annotations: dict | None = None,
attempt: int = 0,
stimulus_id: str,
) -> ComputeTaskEvent:
"""Build a dummy event, with most attributes set to a reasonable default.
Expand All @@ -813,6 +827,7 @@ def dummy(
actor=actor,
annotations=annotations or {},
stimulus_id=stimulus_id,
attempt=attempt,
)


Expand All @@ -832,6 +847,7 @@ class ExecuteSuccessEvent(ExecuteDoneEvent):
start: float
stop: float
nbytes: int
attempt: int
type: type | None
__slots__ = tuple(__annotations__)

Expand All @@ -856,6 +872,7 @@ def _after_from_dict(self) -> None:
def dummy(
key: str,
value: object = None,
attempt: int | None = None,
*,
nbytes: int = 1,
stimulus_id: str,
Expand All @@ -869,6 +886,7 @@ def dummy(
start=0.0,
stop=1.0,
nbytes=nbytes,
attempt=1,
type=None,
stimulus_id=stimulus_id,
)
Expand Down Expand Up @@ -989,7 +1007,12 @@ class RemoveReplicasEvent(StateMachineEvent):
@dataclass
class FreeKeysEvent(StateMachineEvent):
__slots__ = ("keys",)
keys: Collection[str]
keys: Sequence[str]


@dataclass
class FreeKeyByAttemptEvent(FreeKeysEvent):
attempts: Sequence[int]


@dataclass
Expand Down Expand Up @@ -1759,7 +1782,7 @@ def _next_ready_task(self) -> TaskState | None:
return None

def _get_task_finished_msg(
self, ts: TaskState, stimulus_id: str
self, ts: TaskState, stimulus_id: str, attempt: int
) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
raise RuntimeError(f"Task {ts} not ready")
Expand All @@ -1786,6 +1809,7 @@ def _get_task_finished_msg(
metadata=ts.metadata,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
attempt=attempt,
stimulus_id=stimulus_id,
)

Expand Down Expand Up @@ -2422,21 +2446,21 @@ def _transition_cancelled_or_resumed_long_running(
return self._ensure_computing()

def _transition_executing_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
self, ts: TaskState, value: object, attempt: int, *, stimulus_id: str
) -> RecsInstrs:
"""This transition is *normally* triggered by ExecuteSuccessEvent.
However, beware that it can also be triggered by scatter().
"""
return self._transition_to_memory(
ts, value, "task-finished", stimulus_id=stimulus_id
ts, value, "task-finished", attempt=attempt, stimulus_id=stimulus_id
)

def _transition_released_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
) -> RecsInstrs:
"""This transition is triggered by scatter()"""
return self._transition_to_memory(
ts, value, "add-keys", stimulus_id=stimulus_id
ts, value, "add-keys", attempt=-1, stimulus_id=stimulus_id
)

def _transition_flight_memory(
Expand All @@ -2446,11 +2470,11 @@ def _transition_flight_memory(
However, beware that it can also be triggered by scatter().
"""
return self._transition_to_memory(
ts, value, "add-keys", stimulus_id=stimulus_id
ts, value, "add-keys", attempt=-1, stimulus_id=stimulus_id
)

def _transition_resumed_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
self, ts: TaskState, value: object, attempt: int, *, stimulus_id: str
) -> RecsInstrs:
"""Normally, we send to the scheduler a 'task-finished' message for a completed
execution and 'add-data' for a completed replication from another worker. The
Expand All @@ -2472,13 +2496,16 @@ def _transition_resumed_memory(

ts.previous = None
ts.next = None
return self._transition_to_memory(ts, value, msg_type, stimulus_id=stimulus_id)
return self._transition_to_memory(
ts, value, msg_type, attempt=attempt, stimulus_id=stimulus_id
)

def _transition_to_memory(
self,
ts: TaskState,
value: object,
msg_type: Literal["add-keys", "task-finished"],
attempt: int,
*,
stimulus_id: str,
) -> RecsInstrs:
Expand All @@ -2495,7 +2522,11 @@ def _transition_to_memory(
instrs.append(AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id))
else:
assert msg_type == "task-finished"
instrs.append(self._get_task_finished_msg(ts, stimulus_id=stimulus_id))
instrs.append(
self._get_task_finished_msg(
ts, stimulus_id=stimulus_id, attempt=attempt
)
)
return recs, instrs

def _transition_released_forgotten(
Expand Down Expand Up @@ -2790,6 +2821,22 @@ def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs:
recommendations[ts] = "released"
return recommendations, []

@_handle_event.register
def _handle_free_keys_attempt(self, ev: FreeKeyByAttemptEvent) -> RecsInstrs:
"""Handler to be called by the scheduler.

Similar to _handle_free_keys, but only acts if the provided attempt counter matches the known one.
"""
self.log.append(
("free-keys-by-attempt", ev.keys, ev.attempts, ev.stimulus_id, time())
)
recommendations: Recs = {}
for key, attempt in zip(ev.keys, ev.attempts):
ts = self.tasks.get(key)
if ts and ts.attempt == attempt:
recommendations[ts] = "released"
return recommendations, []

@_handle_event.register
def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs:
"""Stream handler notifying the worker that it might be holding unreferenced,
Expand Down Expand Up @@ -2869,7 +2916,7 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
except KeyError:
self.tasks[ev.key] = ts = TaskState(ev.key)
self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time()))

ts.attempt = ev.attempt
recommendations: Recs = {}
instructions: Instructions = []

Expand All @@ -2881,7 +2928,9 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
pass
elif ts.state == "memory":
instructions.append(
self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id)
self._get_task_finished_msg(
ts, stimulus_id=ev.stimulus_id, attempt=ev.attempt
)
)
elif ts.state == "error":
instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id))
Expand Down Expand Up @@ -2977,7 +3026,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
recommendations: Recs = {}
for ts in self._gather_dep_done_common(ev):
if ts.key in ev.data:
recommendations[ts] = ("memory", ev.data[ts.key])
recommendations[ts] = ("memory", ev.data[ts.key], ts.attempt)
else:
self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
if self.validate:
Expand Down Expand Up @@ -3172,7 +3221,7 @@ def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs:
)
ts.nbytes = ev.nbytes
ts.type = ev.type
recs[ts] = ("memory", ev.value)
recs[ts] = ("memory", ev.value, ev.attempt)
return recs, instr

@_handle_event.register
Expand Down