diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 81a1ff6636f..142401eef1a 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3024,6 +3024,7 @@ def __init__( "keep-alive": lambda *args, **kwargs: None, "log-event": self.log_worker_event, "worker-status-change": self.handle_worker_status_change, + "request-refresh-who-has": self.handle_request_refresh_who_has, } client_handlers = { @@ -4769,6 +4770,21 @@ async def handle_worker_status_change( address=ws.address, stimulus_id=stimulus_id, close=False ) + async def handle_request_refresh_who_has( + self, keys: Iterable[str], worker: str, stimulus_id: str + ) -> None: + """Asynchronous request (through bulk comms) from a Worker to refresh the + who_has for some keys. Not to be confused with scheduler.who_has, which is a + synchronous RPC request from a Client. + """ + self.stream_comms[worker].send( + { + "op": "refresh-who-has", + "who_has": self.get_who_has(keys), + "stimulus_id": stimulus_id, + }, + ) + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -6216,13 +6232,13 @@ def get_processing(self, workers=None): w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() } - def get_who_has(self, keys=None): + def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]: if keys is not None: return { - k: [ws.address for ws in self.tasks[k].who_has] - if k in self.tasks + key: [ws.address for ws in self.tasks[key].who_has] + if key in self.tasks else [] - for k in keys + for key in keys } else: return { diff --git a/distributed/worker.py b/distributed/worker.py index f7746bf413b..6a31ce0eae2 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -115,6 +115,7 @@ Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + FindMissingEvent, GatherDep, GatherDepDoneEvent, Instructions, @@ -123,7 +124,9 @@ MissingDataMsg, Recs, RecsInstrs, + RefreshWhoHasEvent, 
ReleaseWorkerDataMsg, + RequestRefreshWhoHasMsg, RescheduleEvent, RescheduleMsg, SendMessageToScheduler, @@ -813,6 +816,7 @@ def __init__( "compute-task": self.handle_compute_task, "free-keys": self.handle_free_keys, "remove-replicas": self.handle_remove_replicas, + "refresh-who-has": self.handle_refresh_who_has, "steal-request": self.handle_steal_request, "worker-status-change": self.handle_worker_status_change, } @@ -841,9 +845,7 @@ def __init__( ) self.periodic_callbacks["keep-alive"] = pc - # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117 - pc = PeriodicCallback(self.find_missing, 1000) # type: ignore - self._find_missing_running = False + pc = PeriodicCallback(self.find_missing, 1000) self.periodic_callbacks["find-missing"] = pc self._address = contact_address @@ -1829,6 +1831,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: return "OK" + def handle_refresh_who_has( + self, who_has: dict[str, list[str]], stimulus_id: str + ) -> None: + self.handle_stimulus( + RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id) + ) + async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): if r in self.total_resources: @@ -2845,7 +2854,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: @log_errors def handle_stimulus(self, stim: StateMachineEvent) -> None: - self.stimulus_log.append(stim.to_loggable(handled=time())) + if not isinstance(stim, FindMissingEvent): + self.stimulus_log.append(stim.to_loggable(handled=time())) recs, instructions = self.handle_event(stim) self.transitions(recs, stimulus_id=stim.stimulus_id) self._handle_instructions(instructions) @@ -3364,22 +3374,19 @@ def done_event(): ) ) recommendations[ts] = "fetch" - del data, response - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) if refresh_who_has: # All workers that hold known replicas of our tasks are busy. 
# Try querying the scheduler for unknown ones. - who_has = await retry_operation( - self.scheduler.who_has, keys=refresh_who_has - ) - refresh_stimulus_id = f"refresh-who-has-{time()}" - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=refresh_stimulus_id + instructions.append( + RequestRefreshWhoHasMsg( + keys=list(refresh_who_has), + stimulus_id=f"gather-dep-busy-{time()}", + ) ) - self.transitions(recommendations, stimulus_id=refresh_stimulus_id) - self._handle_instructions(instructions) + + self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3389,36 +3396,13 @@ def _readd_busy_worker(self, worker: str) -> None: ) @log_errors - async def find_missing(self) -> None: - if self._find_missing_running or not self._missing_dep_flight: - return - try: - self._find_missing_running = True - if self.validate: - for ts in self._missing_dep_flight: - assert not ts.who_has + def find_missing(self) -> None: + self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}")) - stimulus_id = f"find-missing-{time()}" - who_has = await retry_operation( - self.scheduler.who_has, - keys=[ts.key for ts in self._missing_dep_flight], - ) - recommendations, instructions = self._update_who_has( - who_has, stimulus_id=stimulus_id - ) - for ts in self._missing_dep_flight: - if ts.who_has: - assert ts not in recommendations - recommendations[ts] = "fetch" - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) - - finally: - self._find_missing_running = False - # This is quite arbitrary but the heartbeat has scaling implemented - self.periodic_callbacks[ - "find-missing" - ].callback_time = self.periodic_callbacks["heartbeat"].callback_time + # This is quite arbitrary but the heartbeat has scaling implemented + self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[ 
+ "heartbeat" + ].callback_time def _update_who_has( self, who_has: Mapping[str, Collection[str]], *, stimulus_id: str @@ -3487,12 +3471,15 @@ def _update_who_has( ts.who_has = workers # currently fetching -> can no longer be fetched -> transition to missing + # currently missing -> opportunity to be fetched -> transition to fetch # any other state -> eventually, possibly, the task may transition to fetch # or missing, at which point the relevant transitions will test who_has that # we just updated. e.g. see the various transitions to fetch, which # instead recommend transitioning to missing if who_has is empty. if not workers and ts.state == "fetch": recs[ts] = "missing" + elif workers and ts.state == "missing": + recs[ts] = "fetch" return recs, instructions @@ -4018,6 +4005,25 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs: assert ts, self.story(ev.key) return {ts: "rescheduled"}, [] + @handle_event.register + def _(self, ev: FindMissingEvent) -> RecsInstrs: + if not self._missing_dep_flight: + return {}, [] + + if self.validate: + for ts in self._missing_dep_flight: + assert not ts.who_has + + smsg = RequestRefreshWhoHasMsg( + keys=[ts.key for ts in self._missing_dep_flight], + stimulus_id=ev.stimulus_id, + ) + return {}, [smsg] + + @handle_event.register + def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs: + return self._update_who_has(ev.who_has, stimulus_id=ev.stimulus_id) + def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] ) -> tuple[tuple, dict[str, Any]]: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index f5fa39c0802..ec8a54fe6bc 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -379,6 +379,26 @@ class AddKeysMsg(SendMessageToScheduler): keys: list[str] +@dataclass +class RequestRefreshWhoHasMsg(SendMessageToScheduler): + """Worker -> Scheduler asynchronous request for updated who_has information. 
+ Not to be confused with the scheduler.who_has synchronous RPC call, which is used + by the Client. + + See also + -------- + RefreshWhoHasEvent + distributed.scheduler.Scheduler.handle_request_refresh_who_has + distributed.client.Client.who_has + distributed.scheduler.Scheduler.get_who_has + """ + + op = "request-refresh-who-has" + + __slots__ = ("keys",) + keys: list[str] + + @dataclass class StateMachineEvent: __slots__ = ("stimulus_id", "handled") @@ -508,6 +528,25 @@ class RescheduleEvent(StateMachineEvent): key: str +@dataclass +class FindMissingEvent(StateMachineEvent): + __slots__ = () + + +@dataclass +class RefreshWhoHasEvent(StateMachineEvent): + """Scheduler -> Worker message containing updated who_has information. + + See also + -------- + RequestRefreshWhoHasMsg + """ + + __slots__ = ("who_has",) + # {key: [worker address, ...]} + who_has: dict[str, list[str]] + + if TYPE_CHECKING: # TODO remove quotes (requires Python >=3.9) # TODO get out of TYPE_CHECKING (requires Python >=3.10)