From a6d9cc76b1c8e09d167f5798d14b08b01441c1fc Mon Sep 17 00:00:00 2001 From: fjetter Date: Tue, 6 Dec 2022 15:24:08 +0100 Subject: [PATCH 1/2] Sign every compute task with a unique counter to correlated responses --- distributed/scheduler.py | 16 +++++++++++--- .../tests/test_worker_state_machine.py | 3 +++ distributed/worker_state_machine.py | 21 +++++++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 43d9641928..dec0b279e8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1310,6 +1310,14 @@ class TaskState: #: Task annotations annotations: dict[str, Any] + #: A counter that counts how often a task was already assigned to a Worker. + #: This counter is used to sign a task such that the assigned Worker is + #: expected to return the same counter in the task-finished message. This is + #: used to correlate responses. + #: Only the most recently assigned worker is trusted. All other results + #: will be rejected + _attempt: int + #: Cached hash of :attr:`~TaskState.client_key` _hash: int @@ -1354,6 +1362,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState): self.metadata = {} self.annotations = {} self.erred_on = set() + self._attempt = 0 TaskState._instances.add(self) def __hash__(self) -> int: @@ -3257,9 +3266,10 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]: # time to compute and submit this if duration < 0: duration = self.get_task_duration(ts) - + ts._attempt += 1 msg: dict[str, Any] = { "op": "compute-task", + "attempt": ts._attempt, "key": ts.key, "priority": ts.priority, "duration": duration, @@ -4620,7 +4630,7 @@ def stimulus_queue_slots_maybe_opened(self, *, stimulus_id: str) -> None: self.transitions(recommendations, stimulus_id) - def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs): + def stimulus_task_finished(self, key, worker, stimulus_id, attempt, **kwargs): """Mark that a task has finished execution on a particular worker""" logger.debug("Stimulus task finished %s, %s", key, worker) @@ -4630,7 +4640,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar ws: WorkerState = self.workers[worker] ts: TaskState = self.tasks.get(key) - if ts is None or ts.state in ("released", "queued"): + if ts is None or ts._attempt != attempt: logger.debug( "Received already computed task, worker: %s, state: %s" ", key: %s, who_has: %s", diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 6436a6ddc4..92e3f510c8 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -363,6 +363,7 @@ def test_computetask_to_dict(): function=b"blob", args=b"blob", kwargs=None, + attempt=5, ) assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob") ev2 = ev.to_loggable(handled=11.22) @@ -386,6 +387,7 @@ def test_computetask_to_dict(): "function": None, "args": None, "kwargs": None, + "attempt": 5, } ev3 = StateMachineEvent.from_dict(d) assert isinstance(ev3, ComputeTaskEvent) @@ -409,6 +411,7 @@ def test_computetask_dummy(): function=None, args=None, kwargs=None, + attempt=0, ) # nbytes is generated from who_has if omitted diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 8a6237f429..1feb68dbde 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -285,6 +285,7 @@ class TaskState: #: the behaviour of transitions out of the ``executing``, ``flight`` etc. states. done: bool = False + attempt: int = 0 _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet() # Support for weakrefs to a class with __slots__ @@ -459,6 +460,7 @@ class TaskFinishedMsg(SendMessageToScheduler): metadata: dict thread: int | None startstops: list[StartStop] + attempt: int __slots__ = tuple(__annotations__) def to_dict(self) -> dict[str, Any]: @@ -749,6 +751,8 @@ class ComputeTaskEvent(StateMachineEvent): resource_restrictions: dict[str, float] actor: bool annotations: dict + attempt: int + __slots__ = tuple(__annotations__) def __post_init__(self) -> None: @@ -794,6 +798,7 @@ def dummy( resource_restrictions: dict[str, float] | None = None, actor: bool = False, annotations: dict | None = None, + attempt: int = 0, stimulus_id: str, ) -> ComputeTaskEvent: """Build a dummy event, with most attributes set to a reasonable default. @@ -813,6 +818,7 @@ def dummy( actor=actor, annotations=annotations or {}, stimulus_id=stimulus_id, + attempt=attempt, ) @@ -1759,7 +1765,7 @@ def _next_ready_task(self) -> TaskState | None: return None def _get_task_finished_msg( - self, ts: TaskState, stimulus_id: str + self, ts: TaskState, stimulus_id: str, attempt: int ) -> TaskFinishedMsg: if ts.key not in self.data and ts.key not in self.actors: raise RuntimeError(f"Task {ts} not ready") @@ -1786,6 +1792,7 @@ def _get_task_finished_msg( metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, + attempt=attempt, stimulus_id=stimulus_id, ) @@ -2495,7 +2502,11 @@ def _transition_to_memory( instrs.append(AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id)) else: assert msg_type == "task-finished" - instrs.append(self._get_task_finished_msg(ts, stimulus_id=stimulus_id)) + instrs.append( + self._get_task_finished_msg( + ts, stimulus_id=stimulus_id, attempt=ts.attempt + ) + ) return recs, instrs def _transition_released_forgotten( @@ -2869,7 +2880,7 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: except KeyError: self.tasks[ev.key] = ts = TaskState(ev.key) self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time())) - + ts.attempt = ev.attempt recommendations: Recs = {} instructions: Instructions = [] @@ -2881,7 +2892,9 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: pass elif ts.state == "memory": instructions.append( - self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id) + self._get_task_finished_msg( + ts, stimulus_id=ev.stimulus_id, attempt=ev.attempt + ) ) elif ts.state == "error": instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id)) From 4ef89754933e8106b9ab4d1f0ae1c9a97a0e2896 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 4 Jan 2023 16:37:28 +0100 Subject: [PATCH 2/2] Fix attempt before dispatching to threadpool --- distributed/scheduler.py | 1 + distributed/worker.py | 4 ++ distributed/worker_state_machine.py | 58 +++++++++++++++++++++++------ 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index dec0b279e8..fd150b3fe4 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -4654,6 +4654,7 @@ def stimulus_task_finished(self, key, worker, stimulus_id, attempt, **kwargs): "op": "free-keys", "keys": [key], "stimulus_id": stimulus_id, + "attempts": [attempt], } ] elif ts.state == "memory": diff --git a/distributed/worker.py b/distributed/worker.py index b159e42c4f..863e4a8bf6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -114,6 +114,7 @@ ExecuteFailureEvent, ExecuteSuccessEvent, FindMissingEvent, + FreeKeyByAttemptEvent, FreeKeysEvent, GatherDepBusyEvent, GatherDepFailureEvent, @@ -735,6 +736,7 @@ def __init__( "acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent), "compute-task": self._handle_remote_stimulus(ComputeTaskEvent), "free-keys": self._handle_remote_stimulus(FreeKeysEvent), + "free_keys_attempt": self._handle_remote_stimulus(FreeKeyByAttemptEvent), "remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent), "steal-request": self._handle_remote_stimulus(StealRequestEvent), "refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent), @@ -2240,6 +2242,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: # The key *must* be in the worker state thanks to the cancelled state ts = self.state.tasks[key] + execution_attempt = ts.attempt try: function, args, kwargs = await self._maybe_deserialize_task(ts) @@ -2316,6 +2319,7 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: stop=result["stop"], nbytes=result["nbytes"], type=result["type"], + attempt=execution_attempt, stimulus_id=f"task-finished-{time()}", ) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 1feb68dbde..14d6b4b7bf 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -23,7 +23,16 @@ from dataclasses import dataclass, field from functools import lru_cache, partial, singledispatchmethod from itertools import chain -from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Literal, + NamedTuple, + Sequence, + TypedDict, + cast, +) from tlz import peekn @@ -838,6 +847,7 @@ class ExecuteSuccessEvent(ExecuteDoneEvent): start: float stop: float nbytes: int + attempt: int type: type | None __slots__ = tuple(__annotations__) @@ -862,6 +872,7 @@ def _after_from_dict(self) -> None: def dummy( key: str, value: object = None, + attempt: int | None = None, *, nbytes: int = 1, stimulus_id: str, @@ -875,6 +886,7 @@ def dummy( start=0.0, stop=1.0, nbytes=nbytes, + attempt=1, type=None, stimulus_id=stimulus_id, ) @@ -995,7 +1007,12 @@ class RemoveReplicasEvent(StateMachineEvent): @dataclass class FreeKeysEvent(StateMachineEvent): __slots__ = ("keys",) - keys: Collection[str] + keys: Sequence[str] + + +@dataclass +class FreeKeyByAttemptEvent(FreeKeysEvent): + attempts: Sequence[int] @dataclass @@ -2429,13 +2446,13 @@ def _transition_cancelled_or_resumed_long_running( return self._ensure_computing() def _transition_executing_memory( - self, ts: TaskState, value: object, *, stimulus_id: str + self, ts: TaskState, value: object, attempt: int, *, stimulus_id: str ) -> RecsInstrs: """This transition is *normally* triggered by ExecuteSuccessEvent. However, beware that it can also be triggered by scatter(). """ return self._transition_to_memory( - ts, value, "task-finished", stimulus_id=stimulus_id + ts, value, "task-finished", attempt=attempt, stimulus_id=stimulus_id ) def _transition_released_memory( @@ -2443,7 +2460,7 @@ def _transition_released_memory( ) -> RecsInstrs: """This transition is triggered by scatter()""" return self._transition_to_memory( - ts, value, "add-keys", stimulus_id=stimulus_id + ts, value, "add-keys", attempt=-1, stimulus_id=stimulus_id ) def _transition_flight_memory( @@ -2453,11 +2470,11 @@ def _transition_flight_memory( However, beware that it can also be triggered by scatter(). """ return self._transition_to_memory( - ts, value, "add-keys", stimulus_id=stimulus_id + ts, value, "add-keys", attempt=-1, stimulus_id=stimulus_id ) def _transition_resumed_memory( - self, ts: TaskState, value: object, *, stimulus_id: str + self, ts: TaskState, value: object, attempt: int, *, stimulus_id: str ) -> RecsInstrs: """Normally, we send to the scheduler a 'task-finished' message for a completed execution and 'add-data' for a completed replication from another worker. The @@ -2479,13 +2496,16 @@ def _transition_resumed_memory( ts.previous = None ts.next = None - return self._transition_to_memory(ts, value, msg_type, stimulus_id=stimulus_id) + return self._transition_to_memory( + ts, value, msg_type, attempt=attempt, stimulus_id=stimulus_id + ) def _transition_to_memory( self, ts: TaskState, value: object, msg_type: Literal["add-keys", "task-finished"], + attempt: int, *, stimulus_id: str, ) -> RecsInstrs: @@ -2504,7 +2524,7 @@ def _transition_to_memory( assert msg_type == "task-finished" instrs.append( self._get_task_finished_msg( - ts, stimulus_id=stimulus_id, attempt=ts.attempt + ts, stimulus_id=stimulus_id, attempt=attempt ) ) return recs, instrs @@ -2801,6 +2821,22 @@ def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs: recommendations[ts] = "released" return recommendations, [] + @_handle_event.register + def _handle_free_keys_attempt(self, ev: FreeKeyByAttemptEvent) -> RecsInstrs: + """Handler to be called by the scheduler. + + Similar to _handle_free_keys but will only act if the provided attempt counter matches the known one + """ + self.log.append( + ("free-keys-by-attempt", ev.keys, ev.attempts, ev.stimulus_id, time()) + ) + recommendations: Recs = {} + for key, attempt in zip(ev.keys, ev.attempts): + ts = self.tasks.get(key) + if ts and ts.attempt == attempt: + recommendations[ts] = "released" + return recommendations, [] + @_handle_event.register def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs: """Stream handler notifying the worker that it might be holding unreferenced, @@ -2990,7 +3026,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: recommendations: Recs = {} for ts in self._gather_dep_done_common(ev): if ts.key in ev.data: - recommendations[ts] = ("memory", ev.data[ts.key]) + recommendations[ts] = ("memory", ev.data[ts.key], ts.attempt) else: self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) if self.validate: @@ -3185,7 +3221,7 @@ def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: ) ts.nbytes = ev.nbytes ts.type = ev.type - recs[ts] = ("memory", ev.value) + recs[ts] = ("memory", ev.value, ev.attempt) return recs, instr @_handle_event.register