diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 048661ac2f2..d4a84c184b5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -3031,6 +3031,7 @@ def __init__( "keep-alive": lambda *args, **kwargs: None, "log-event": self.log_worker_event, "worker-status-change": self.handle_worker_status_change, + "request-refresh-who-has": self.handle_request_refresh_who_has, } client_handlers = { @@ -4782,6 +4783,21 @@ def handle_worker_status_change( else: self.running.discard(ws) + async def handle_request_refresh_who_has( + self, keys: Iterable[str], worker: str, stimulus_id: str + ) -> None: + """Asynchronous request (through bulk comms) from a Worker to refresh the + who_has for some keys. Not to be confused with scheduler.who_has, which is a + synchronous RPC request from a Client. + """ + self.stream_comms[worker].send( + { + "op": "refresh-who-has", + "who_has": self.get_who_has(keys), + "stimulus_id": stimulus_id, + }, + ) + async def handle_worker(self, comm=None, worker=None, stimulus_id=None): """ Listen to responses from a single worker @@ -6230,13 +6246,13 @@ def get_processing(self, workers=None): w: [ts.key for ts in ws.processing] for w, ws in self.workers.items() } - def get_who_has(self, keys=None): + def get_who_has(self, keys: Iterable[str] | None = None) -> dict[str, list[str]]: if keys is not None: return { - k: [ws.address for ws in self.tasks[k].who_has] - if k in self.tasks + key: [ws.address for ws in self.tasks[key].who_has] + if key in self.tasks else [] - for k in keys + for key in keys } else: return { diff --git a/distributed/worker.py b/distributed/worker.py index 0ad15cd3fb0..2ad5d9c53a5 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -115,6 +115,7 @@ Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + FindMissingEvent, GatherDep, GatherDepDoneEvent, Instructions, @@ -123,7 +124,9 @@ MissingDataMsg, Recs, RecsInstrs, + RefreshWhoHasEvent, ReleaseWorkerDataMsg, + RequestRefreshWhoHasMsg, 
RescheduleEvent, RescheduleMsg, SendMessageToScheduler, @@ -813,6 +816,7 @@ def __init__( "free-keys": self.handle_free_keys, "remove-replicas": self.handle_remove_replicas, "steal-request": self.handle_steal_request, + "refresh-who-has": self.handle_refresh_who_has, "worker-status-change": self.handle_worker_status_change, } @@ -840,9 +844,7 @@ def __init__( ) self.periodic_callbacks["keep-alive"] = pc - # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117 - pc = PeriodicCallback(self.find_missing, 1000) # type: ignore - self._find_missing_running = False + pc = PeriodicCallback(self.find_missing, 1000) self.periodic_callbacks["find-missing"] = pc self._address = contact_address @@ -1839,6 +1841,13 @@ def handle_remove_replicas(self, keys: list[str], stimulus_id: str) -> str: return "OK" + def handle_refresh_who_has( + self, who_has: dict[str, list[str]], stimulus_id: str + ) -> None: + self.handle_stimulus( + RefreshWhoHasEvent(who_has=who_has, stimulus_id=stimulus_id) + ) + async def set_resources(self, **resources) -> None: for r, quantity in resources.items(): if r in self.total_resources: @@ -2849,7 +2858,8 @@ def transitions(self, recommendations: Recs, *, stimulus_id: str) -> None: @log_errors def handle_stimulus(self, stim: StateMachineEvent) -> None: - self.stimulus_log.append(stim.to_loggable(handled=time())) + if not isinstance(stim, FindMissingEvent): + self.stimulus_log.append(stim.to_loggable(handled=time())) recs, instructions = self.handle_event(stim) self.transitions(recs, stimulus_id=stim.stimulus_id) self._handle_instructions(instructions) @@ -2991,11 +3001,8 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: if ts.state != "fetch" or ts.key in all_keys_to_gather: continue - if not ts.who_has: - recommendations[ts] = "missing" - continue - if self.validate: + assert ts.who_has assert self.address not in ts.who_has workers = [ @@ -3348,7 +3355,7 @@ def done_event(): self.busy_workers.add(worker) 
self.io_loop.call_later(0.15, self._readd_busy_worker, worker) - refresh_who_has = set() + refresh_who_has = [] for d in self.in_flight_workers.pop(worker): ts = self.tasks[d] @@ -3358,7 +3365,7 @@ def done_event(): elif busy: recommendations[ts] = "fetch" if not ts.who_has - self.busy_workers: - refresh_who_has.add(ts.key) + refresh_who_has.append(d) elif ts not in recommendations: ts.who_has.discard(worker) self.has_what[worker].discard(ts.key) @@ -3371,17 +3378,19 @@ def done_event(): ) ) recommendations[ts] = "fetch" - del data, response - self.transitions(recommendations, stimulus_id=stimulus_id) - self._handle_instructions(instructions) if refresh_who_has: # All workers that hold known replicas of our tasks are busy. # Try querying the scheduler for unknown ones. - who_has = await retry_operation( - self.scheduler.who_has, keys=refresh_who_has + instructions.append( + RequestRefreshWhoHasMsg( + keys=refresh_who_has, + stimulus_id=f"gather-dep-busy-{time()}", + ) ) - self._update_who_has(who_has) + + self.transitions(recommendations, stimulus_id=stimulus_id) + self._handle_instructions(instructions) @log_errors def _readd_busy_worker(self, worker: str) -> None: @@ -3391,33 +3400,13 @@ def _readd_busy_worker(self, worker: str) -> None: ) @log_errors - async def find_missing(self) -> None: - if self._find_missing_running or not self._missing_dep_flight: - return - try: - self._find_missing_running = True - if self.validate: - for ts in self._missing_dep_flight: - assert not ts.who_has + def find_missing(self) -> None: + self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}")) - stimulus_id = f"find-missing-{time()}" - who_has = await retry_operation( - self.scheduler.who_has, - keys=[ts.key for ts in self._missing_dep_flight], - ) - self._update_who_has(who_has) - recommendations: Recs = {} - for ts in self._missing_dep_flight: - if ts.who_has: - recommendations[ts] = "fetch" - self.transitions(recommendations, stimulus_id=stimulus_id) - - 
finally: - self._find_missing_running = False - # This is quite arbitrary but the heartbeat has scaling implemented - self.periodic_callbacks[ - "find-missing" - ].callback_time = self.periodic_callbacks["heartbeat"].callback_time + # This is quite arbitrary but the heartbeat has scaling implemented + self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[ + "heartbeat" + ].callback_time def _update_who_has(self, who_has: Mapping[str, Collection[str]]) -> None: for key, workers in who_has.items(): @@ -3965,6 +3954,47 @@ def _(self, ev: RescheduleEvent) -> RecsInstrs: assert ts, self.story(ev.key) return {ts: "rescheduled"}, [] + @handle_event.register + def _(self, ev: FindMissingEvent) -> RecsInstrs: + if not self._missing_dep_flight: + return {}, [] + + if self.validate: + assert not any(ts.who_has for ts in self._missing_dep_flight) + + smsg = RequestRefreshWhoHasMsg( + keys=[ts.key for ts in self._missing_dep_flight], + stimulus_id=ev.stimulus_id, + ) + return {}, [smsg] + + @handle_event.register + def _(self, ev: RefreshWhoHasEvent) -> RecsInstrs: + self._update_who_has(ev.who_has) + recommendations: Recs = {} + instructions: Instructions = [] + + for key in ev.who_has: + ts = self.tasks.get(key) + if not ts: + continue + + if ts.who_has and ts.state == "missing": + recommendations[ts] = "fetch" + elif ts.who_has and ts.state == "fetch": + # We potentially just acquired new replicas whereas all previously known + # workers are in flight or busy. We're deliberately not testing the + # minute use cases here for the sake of simplicity; instead we rely on + # _ensure_communicating to be a no-op when there's nothing to do. 
+ recommendations, instructions = merge_recs_instructions( + (recommendations, instructions), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + elif not ts.who_has and ts.state == "fetch": + recommendations[ts] = "missing" + + return recommendations, instructions + def _prepare_args_for_execution( self, ts: TaskState, args: tuple, kwargs: dict[str, Any] ) -> tuple[tuple, dict[str, Any]]: @@ -4190,8 +4220,8 @@ def validate_task_fetch(self, ts): assert self.address not in ts.who_has assert not ts.done assert ts in self.data_needed - # Note: ts.who_has may be have been emptied by _update_who_has, but the task - # won't transition to missing until it reaches the top of the data_needed heap. + assert ts.who_has + for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed_per_worker[w] diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 060028aa24b..2f311cdb482 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -380,6 +380,26 @@ class AddKeysMsg(SendMessageToScheduler): keys: list[str] +@dataclass +class RequestRefreshWhoHasMsg(SendMessageToScheduler): + """Worker -> Scheduler asynchronous request for updated who_has information. + Not to be confused with the scheduler.who_has synchronous RPC call, which is used + by the Client. + + See also + -------- + RefreshWhoHasEvent + distributed.scheduler.Scheduler.handle_request_refresh_who_has + distributed.client.Client.who_has + distributed.scheduler.Scheduler.get_who_has + """ + + op = "request-refresh-who-has" + + __slots__ = ("keys",) + keys: list[str] + + @dataclass class StateMachineEvent: __slots__ = ("stimulus_id", "handled") @@ -533,6 +553,25 @@ class RescheduleEvent(StateMachineEvent): key: str +@dataclass +class FindMissingEvent(StateMachineEvent): + __slots__ = () + + +@dataclass +class RefreshWhoHasEvent(StateMachineEvent): + """Scheduler -> Worker message containing updated who_has information. 
+ + See also + -------- + RequestRefreshWhoHasMsg + """ + + __slots__ = ("who_has",) + # {key: [worker address, ...]} + who_has: dict[str, list[str]] + + if TYPE_CHECKING: # TODO remove quotes (requires Python >=3.9) # TODO get out of TYPE_CHECKING (requires Python >=3.10)