diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py
index beca9d40734..f30c4440ff8 100644
--- a/distributed/tests/test_worker_state_machine.py
+++ b/distributed/tests/test_worker_state_machine.py
@@ -106,14 +106,16 @@ def test_WorkerState__to_dict():
         "busy_workers": [],
         "constrained": [],
         "data": {"y": None},
-        "data_needed": ["x"],
-        "data_needed_per_worker": {"127.0.0.1:1235": ["x"]},
+        "data_needed": [],
+        "data_needed_per_worker": {"127.0.0.1:1235": []},
         "executing": [],
-        "in_flight_tasks": [],
-        "in_flight_workers": {},
+        "in_flight_tasks": ["x"],
+        "in_flight_workers": {"127.0.0.1:1235": ["x"]},
         "log": [
             ["x", "ensure-task-exists", "released", "s1"],
             ["x", "released", "fetch", "fetch", {}, "s1"],
+            ["gather-dependencies", "127.0.0.1:1235", ["x"], "s1"],
+            ["x", "fetch", "flight", "flight", {}, "s1"],
             ["y", "put-in-memory", "s2"],
             ["y", "receive-from-scatter", "s2"],
         ],
@@ -137,10 +139,11 @@ def test_WorkerState__to_dict():
         ],
         "tasks": {
             "x": {
+                "coming_from": "127.0.0.1:1235",
                 "key": "x",
                 "nbytes": 123,
                 "priority": [1],
-                "state": "fetch",
+                "state": "flight",
                 "who_has": ["127.0.0.1:1235"],
             },
             "y": {
@@ -149,7 +152,7 @@ def test_WorkerState__to_dict():
                 "state": "memory",
             },
         },
-        "transition_counter": 1,
+        "transition_counter": 2,
     }
     assert actual == expect

@@ -819,3 +822,20 @@ async def test_deprecated_worker_attributes(s, a, b):
     )
     with pytest.warns(FutureWarning, match=msg):
         assert a.in_flight_tasks == 0
+
+
+@pytest.mark.parametrize("nbytes,n_in_flight", [(1, 3), (2**30, 1)])
+def test_cluster_gather_deps(nbytes, n_in_flight):
+    ws = WorkerState(address="127.0.0.1:1234", transition_counter_max=10)
+    ws.handle_stimulus(
+        AcquireReplicasEvent(
+            who_has={
+                "x1": ["127.0.0.1:1235"],
+                "x2": ["127.0.0.1:1235"],
+                "x3": ["127.0.0.1:1235"],
+            },
+            nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes},
+            stimulus_id="test",
+        )
+    )
+    assert len(ws.in_flight_tasks) == n_in_flight
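Note on the new test above: the parametrization pins down the clustering behaviour end to end. With one-byte tasks, a single AcquireReplicasEvent leaves all three keys in flight over one GatherDep to the same worker; with ~1 GiB tasks, only the first key goes in flight, because the transfer-size cap in _select_keys_for_gather stops further keys from being batched onto that connection. A rough interactive sketch of the same idea (hypothetical session; it assumes WorkerState is constructible with just an address and that GatherDep.to_gather is a set of keys):

    from distributed.worker_state_machine import (
        AcquireReplicasEvent,
        GatherDep,
        WorkerState,
    )

    ws = WorkerState(address="127.0.0.1:1234")
    instructions = ws.handle_stimulus(
        AcquireReplicasEvent(
            who_has={"x1": ["127.0.0.1:1235"], "x2": ["127.0.0.1:1235"]},
            nbytes={"x1": 1, "x2": 1},
            stimulus_id="example",
        )
    )
    gather_deps = [inst for inst in instructions if isinstance(inst, GatherDep)]
    assert len(gather_deps) == 1  # both keys clustered into a single transfer
    assert gather_deps[0].to_gather == {"x1", "x2"}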
diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py
index 570f7917c70..f2e60e16307 100644
--- a/distributed/worker_state_machine.py
+++ b/distributed/worker_state_machine.py
@@ -344,12 +344,6 @@ class RetryBusyWorkerLater(Instruction):
     worker: str


-@dataclass
-class EnsureCommunicatingAfterTransitions(Instruction):
-    __slots__ = ()
-
-
-@dataclass
 class SendMessageToScheduler(Instruction):
     #: Matches a key in Scheduler.stream_handlers
     op: ClassVar[str]
@@ -1484,13 +1478,7 @@ def _transition_generic_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsInstrs:
         self.data_needed.add(ts)
         for w in ts.who_has:
             self.data_needed_per_worker[w].add(ts)
-
-        # This is the same as `return self._ensure_communicating()`, except that when
-        # many tasks transition to fetch at the same time, e.g. from a single
-        # compute-task or acquire-replicas command from the scheduler, it allows
-        # clustering the transfers into less GatherDep instructions; see
-        # _select_keys_for_gather().
-        return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)]
+        return {}, []

     def _transition_missing_waiting(
         self, ts: TaskState, *, stimulus_id: str
     ) -> RecsInstrs:
@@ -2273,18 +2261,30 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions:
         reach a steady state
         """
         instructions = []
-
-        remaining_recs = recommendations.copy()
         tasks = set()
-        while remaining_recs:
-            ts, finish = remaining_recs.popitem()
-            tasks.add(ts)
-            a_recs, a_instructions = self._transition(
-                ts, finish, stimulus_id=stimulus_id
-            )
-            remaining_recs.update(a_recs)
-            instructions += a_instructions
+
+        def process_recs(recs: Recs) -> None:
+            while recs:
+                ts, finish = recs.popitem()
+                tasks.add(ts)
+                a_recs, a_instructions = self._transition(
+                    ts, finish, stimulus_id=stimulus_id
+                )
+                recs.update(a_recs)
+                instructions.extend(a_instructions)
+
+        process_recs(recommendations.copy())
+
+        # We could call _ensure_communicating after we change something that could
+        # trigger a new call to gather_dep (e.g. on transitions to fetch,
+        # GatherDepDoneEvent, or RetryBusyWorkerEvent). However, doing so we'd
+        # potentially call it too early, before all tasks have transitioned to fetch.
+        # This in turn would hurt aggregation of multiple tasks into a single GatherDep
+        # instruction.
+        # Read: https://github.com/dask/distributed/issues/6497
+        a_recs, a_instructions = self._ensure_communicating(stimulus_id=stimulus_id)
+        instructions += a_instructions
+        process_recs(a_recs)

         if self.validate:
             # Full state validation is very expensive
@@ -2528,10 +2528,7 @@ def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs:
                 self.has_what[ev.worker].discard(ts.key)
                 recommendations[ts] = "fetch"

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

     @_handle_event.register
     def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
@@ -2560,10 +2557,7 @@ def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs:
             )
         )

-        return merge_recs_instructions(
-            (recommendations, instructions),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, instructions

     @_handle_event.register
     def _handle_gather_dep_network_failure(
@@ -2590,10 +2584,7 @@ def _handle_gather_dep_network_failure(
             self.log.append((ts.key, "missing-dep", ev.stimulus_id, time()))
             recommendations[ts] = "fetch"

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

     @_handle_event.register
     def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
@@ -2611,10 +2602,7 @@ def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs:
             for ts in self._gather_dep_done_common(ev)
         }

-        return merge_recs_instructions(
-            (recommendations, []),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return recommendations, []

     @_handle_event.register
     def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs:
@@ -2654,15 +2642,12 @@ def _handle_pause(self, ev: PauseEvent) -> RecsInstrs:
     def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs:
         """Emerge from paused status"""
         self.running = True
-        return merge_recs_instructions(
-            self._ensure_computing(),
-            self._ensure_communicating(stimulus_id=ev.stimulus_id),
-        )
+        return self._ensure_computing()

     @_handle_event.register
     def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs:
         self.busy_workers.discard(ev.worker)
-        return self._ensure_communicating(stimulus_id=ev.stimulus_id)
+        return {}, []

     @_handle_event.register
     def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs:
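The crux of the change is in _transitions above: recommendations are drained to a fixed point first, and only afterwards does a single _ensure_communicating pass run, with its own recommendations drained the same way; the gather-dep, unpause, and retry-busy-worker handlers can therefore simply return their recommendations. A standalone sketch of that fixed-point pattern (toy string-based types and a hypothetical transition callback; the real code maps TaskState objects to states):

    from typing import Callable

    Recs = dict[str, str]  # toy stand-in: key -> recommended next state

    def drain(recs: Recs, transition: Callable[[str, str], Recs]) -> None:
        # Pop until quiescent; each transition may recommend follow-ups,
        # which are folded back into the same dict.
        while recs:
            key, finish = recs.popitem()
            recs.update(transition(key, finish))

    log: list[tuple[str, str]] = []

    def transition(key: str, finish: str) -> Recs:
        log.append((key, finish))
        return {}  # no follow-up recommendations in this toy

    drain({"x": "fetch", "y": "fetch"}, transition)
    assert sorted(log) == [("x", "fetch"), ("y", "fetch")]

Because _ensure_communicating now runs exactly once per transitions loop, it sees every task that reached fetch during that loop, which preserves the batching that EnsureCommunicatingAfterTransitions used to provide.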
@@ -2764,17 +2749,13 @@ def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs:

             if ts.who_has and ts.state == "missing":
                 recommendations[ts] = "fetch"
-            elif ts.who_has and ts.state == "fetch":
-                # We potentially just acquired new replicas whereas all previously known
-                # workers are in flight or busy. We're deliberately not testing the
-                # minute use cases here for the sake of simplicity; instead we rely on
-                # _ensure_communicating to be a no-op when there's nothing to do.
-                recommendations, instructions = merge_recs_instructions(
-                    (recommendations, instructions),
-                    self._ensure_communicating(stimulus_id=ev.stimulus_id),
-                )
             elif not ts.who_has and ts.state == "fetch":
                 recommendations[ts] = "missing"
+            # Note: if ts.who_has and ts.state == "fetch", we may have just acquired new
+            # replicas whereas all previously known workers are in flight or busy. We
+            # rely on _transitions to call _ensure_communicating every time, even in
+            # absence of recommendations, to potentially kick off a new call to
+            # gather_dep.

         return recommendations, instructions
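The branch logic in _handle_refresh_who_has above is small enough to restate standalone; note how the who_has-and-fetch case deliberately yields no recommendation and leans on the unconditional _ensure_communicating call in _transitions. A purely illustrative restatement (plain strings instead of TaskState objects):

    def refresh_recommendation(state: str, who_has: list[str]) -> str | None:
        if who_has and state == "missing":
            return "fetch"
        if not who_has and state == "fetch":
            return "missing"
        # Covers who_has and state == "fetch": no transition; the
        # _ensure_communicating pass in _transitions picks up the new replicas.
        return None

    assert refresh_recommendation("missing", ["127.0.0.1:1236"]) == "fetch"
    assert refresh_recommendation("fetch", []) == "missing"
    assert refresh_recommendation("fetch", ["127.0.0.1:1236"]) is None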
@@ -3055,73 +3036,45 @@ def handle_stimulus(self, stim: StateMachineEvent) -> None:
         """
         instructions = self.state.handle_stimulus(stim)

-        while instructions:
-            ensure_communicating: EnsureCommunicatingAfterTransitions | None = None
-            for inst in instructions:
-                task: asyncio.Task | None = None
-
-                if isinstance(inst, SendMessageToScheduler):
-                    self.batched_send(inst.to_dict())
-
-                elif isinstance(inst, EnsureCommunicatingAfterTransitions):
-                    # A single compute-task or acquire-replicas command may cause
-                    # multiple tasks to transition to fetch; this in turn means that we
-                    # will receive multiple instances of this instruction.
-                    # _ensure_communicating is a no-op if it runs twice in a row; we're
-                    # not calling it inside the for loop to avoid a O(n^2) condition
-                    # when
-                    # 1. there are many fetches queued because all workers are in flight
-                    # 2. a single compute-task or acquire-replicas command just sent
-                    #    many dependencies to fetch at once.
-                    ensure_communicating = inst
-
-                elif isinstance(inst, GatherDep):
-                    assert inst.to_gather
-                    keys_str = ", ".join(peekn(27, inst.to_gather)[0])
-                    if len(keys_str) > 80:
-                        keys_str = keys_str[:77] + "..."
-                    task = asyncio.create_task(
-                        self.gather_dep(
-                            inst.worker,
-                            inst.to_gather,
-                            total_nbytes=inst.total_nbytes,
-                            stimulus_id=inst.stimulus_id,
-                        ),
-                        name=f"gather_dep({inst.worker}, {{{keys_str}}})",
-                    )
-
-                elif isinstance(inst, Execute):
-                    task = asyncio.create_task(
-                        self.execute(inst.key, stimulus_id=inst.stimulus_id),
-                        name=f"execute({inst.key})",
-                    )
-
-                elif isinstance(inst, RetryBusyWorkerLater):
-                    task = asyncio.create_task(
-                        self.retry_busy_worker_later(inst.worker),
-                        name=f"retry_busy_worker_later({inst.worker})",
-                    )
-
-                else:
-                    raise TypeError(inst)  # pragma: nocover
-
-                if task is not None:
-                    self._async_instructions.add(task)
-                    task.add_done_callback(self._handle_stimulus_from_task)
-
-            if ensure_communicating:
-                # Potentially re-fill instructions, causing a second iteration of `while
-                # instructions` at the top of this method
-                # FIXME access to private methods
-                # https://github.com/dask/distributed/issues/6497
-                recs, instructions = self.state._ensure_communicating(
-                    stimulus_id=ensure_communicating.stimulus_id
-                )
-                instructions += self.state._transitions(
-                    recs, stimulus_id=ensure_communicating.stimulus_id
-                )
-            else:
-                return
+        for inst in instructions:
+            task: asyncio.Task | None = None
+
+            if isinstance(inst, SendMessageToScheduler):
+                self.batched_send(inst.to_dict())
+
+            elif isinstance(inst, GatherDep):
+                assert inst.to_gather
+                keys_str = ", ".join(peekn(27, inst.to_gather)[0])
+                if len(keys_str) > 80:
+                    keys_str = keys_str[:77] + "..."
+                task = asyncio.create_task(
+                    self.gather_dep(
+                        inst.worker,
+                        inst.to_gather,
+                        total_nbytes=inst.total_nbytes,
+                        stimulus_id=inst.stimulus_id,
+                    ),
+                    name=f"gather_dep({inst.worker}, {{{keys_str}}})",
+                )
+
+            elif isinstance(inst, Execute):
+                task = asyncio.create_task(
+                    self.execute(inst.key, stimulus_id=inst.stimulus_id),
+                    name=f"execute({inst.key})",
+                )
+
+            elif isinstance(inst, RetryBusyWorkerLater):
+                task = asyncio.create_task(
+                    self.retry_busy_worker_later(inst.worker),
+                    name=f"retry_busy_worker_later({inst.worker})",
+                )
+
+            else:
+                raise TypeError(inst)  # pragma: nocover
+
+            if task is not None:
+                self._async_instructions.add(task)
+                task.add_done_callback(self._handle_stimulus_from_task)

     async def close(self, timeout: float = 30) -> None:
         """Cancel all asynchronous instructions"""
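With EnsureCommunicatingAfterTransitions gone, BaseWorker.handle_stimulus above reduces to a single linear pass: each instruction either sends a message to the scheduler or spawns an asyncio task, and nothing refills the instruction list mid-loop. A minimal self-contained sketch of that dispatch shape (toy Instruction classes, not the real distributed API):

    import asyncio
    from dataclasses import dataclass

    @dataclass
    class Instruction:
        stimulus_id: str

    @dataclass
    class Execute(Instruction):
        key: str

    async def handle(instructions: list[Instruction]) -> None:
        tasks = []
        for inst in instructions:  # one pass; no `while instructions:` re-entry
            if isinstance(inst, Execute):
                tasks.append(
                    asyncio.create_task(
                        asyncio.sleep(0),  # stand-in for self.execute(inst.key)
                        name=f"execute({inst.key})",
                    )
                )
            else:
                raise TypeError(inst)
        await asyncio.gather(*tasks)

    asyncio.run(handle([Execute(stimulus_id="s1", key="x")]))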