Skip to content

Commit

Permalink
Sign every compute task with a unique counter to correlated responses
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Dec 6, 2022
1 parent 8c81d03 commit 79d9e65
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 7 deletions.
16 changes: 13 additions & 3 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,6 +1319,14 @@ class TaskState:
#: Task annotations
annotations: dict[str, Any]

#: A counter that counts how often a task was already assigned to a Worker.
#: This counter is used to sign a task such that the assigned Worker is
#: expected to return the same counter in the task-finished message. This is
#: used to correlate responses.
#: Only the most recently assigned worker is trusted. All other results
#: will be rejected
_attempt: int

#: Cached hash of :attr:`~TaskState.client_key`
_hash: int

Expand Down Expand Up @@ -1363,6 +1371,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.metadata = {}
self.annotations = {}
self.erred_on = set()
self._attempt = 0
TaskState._instances.add(self)

def __hash__(self) -> int:
Expand Down Expand Up @@ -3297,9 +3306,10 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
# time to compute and submit this
if duration < 0:
duration = self.get_task_duration(ts)

ts._attempt += 1
msg: dict[str, Any] = {
"op": "compute-task",
"attempt": ts._attempt,
"key": ts.key,
"priority": ts.priority,
"duration": duration,
Expand Down Expand Up @@ -4623,7 +4633,7 @@ def update_graph(

# TODO: balance workers

def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwargs):
def stimulus_task_finished(self, key, worker, stimulus_id, attempt, **kwargs):
"""Mark that a task has finished execution on a particular worker"""
logger.debug("Stimulus task finished %s, %s", key, worker)

Expand All @@ -4633,7 +4643,7 @@ def stimulus_task_finished(self, key=None, worker=None, stimulus_id=None, **kwar

ws: WorkerState = self.workers[worker]
ts: TaskState = self.tasks.get(key)
if ts is None or ts.state in ("released", "queued"):
if ts is None or ts._attempt != attempt:
logger.debug(
"Received already computed task, worker: %s, state: %s"
", key: %s, who_has: %s",
Expand Down
3 changes: 3 additions & 0 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def test_computetask_to_dict():
function=b"blob",
args=b"blob",
kwargs=None,
attempt=5,
)
assert ev.run_spec == SerializedTask(function=b"blob", args=b"blob")
ev2 = ev.to_loggable(handled=11.22)
Expand All @@ -386,6 +387,7 @@ def test_computetask_to_dict():
"function": None,
"args": None,
"kwargs": None,
"attempt": 5,
}
ev3 = StateMachineEvent.from_dict(d)
assert isinstance(ev3, ComputeTaskEvent)
Expand All @@ -409,6 +411,7 @@ def test_computetask_dummy():
function=None,
args=None,
kwargs=None,
attempt=0,
)

# nbytes is generated from who_has if omitted
Expand Down
19 changes: 15 additions & 4 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ class TaskState:
#: the behaviour of transitions out of the ``executing``, ``flight`` etc. states.
done: bool = False

attempt: int = 0
_instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet()

# Support for weakrefs to a class with __slots__
Expand Down Expand Up @@ -452,6 +453,7 @@ class TaskFinishedMsg(SendMessageToScheduler):
metadata: dict
thread: int | None
startstops: list[StartStop]
attempt: int
__slots__ = tuple(__annotations__)

def to_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -742,6 +744,8 @@ class ComputeTaskEvent(StateMachineEvent):
resource_restrictions: dict[str, float]
actor: bool
annotations: dict
attempt: int

__slots__ = tuple(__annotations__)

def __post_init__(self) -> None:
Expand Down Expand Up @@ -787,6 +791,7 @@ def dummy(
resource_restrictions: dict[str, float] | None = None,
actor: bool = False,
annotations: dict | None = None,
attempt: int = 0,
stimulus_id: str,
) -> ComputeTaskEvent:
"""Build a dummy event, with most attributes set to a reasonable default.
Expand All @@ -806,6 +811,7 @@ def dummy(
actor=actor,
annotations=annotations or {},
stimulus_id=stimulus_id,
attempt=attempt,
)


Expand Down Expand Up @@ -1748,7 +1754,7 @@ def _next_ready_task(self) -> TaskState | None:
return None

def _get_task_finished_msg(
self, ts: TaskState, stimulus_id: str
self, ts: TaskState, stimulus_id: str, attempt: int
) -> TaskFinishedMsg:
if ts.key not in self.data and ts.key not in self.actors:
raise RuntimeError(f"Task {ts} not ready")
Expand All @@ -1775,6 +1781,7 @@ def _get_task_finished_msg(
metadata=ts.metadata,
thread=self.threads.get(ts.key),
startstops=ts.startstops,
attempt=attempt,
stimulus_id=stimulus_id,
)

Expand Down Expand Up @@ -2471,7 +2478,9 @@ def _transition_to_memory(
smsg: Instruction = AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id)
else:
assert msg_type == "task-finished"
smsg = self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
smsg = self._get_task_finished_msg(
ts, stimulus_id=stimulus_id, attempt=ts.attempt
)
return recs, [smsg]

def _transition_released_forgotten(
Expand Down Expand Up @@ -2844,7 +2853,7 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
except KeyError:
self.tasks[ev.key] = ts = TaskState(ev.key)
self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time()))

ts.attempt = ev.attempt
recommendations: Recs = {}
instructions: Instructions = []

Expand All @@ -2856,7 +2865,9 @@ def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs:
pass
elif ts.state == "memory":
instructions.append(
self._get_task_finished_msg(ts, stimulus_id=ev.stimulus_id)
self._get_task_finished_msg(
ts, stimulus_id=ev.stimulus_id, attempt=ev.attempt
)
)
elif ts.state == "error":
instructions.append(TaskErredMsg.from_task(ts, stimulus_id=ev.stimulus_id))
Expand Down

0 comments on commit 79d9e65

Please sign in to comment.