From 738f7c6b645f169e759a2d07a1563401de1bb89e Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 27 Apr 2022 15:43:37 +0100 Subject: [PATCH] Refactor ensure_communicating --- distributed/tests/test_worker.py | 10 +- distributed/worker.py | 142 ++++++++++++++++++---------- distributed/worker_state_machine.py | 26 ++--- 3 files changed, 109 insertions(+), 69 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 6bea4a3fdf..04073783e9 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1412,21 +1412,13 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: assert_story( w_to.story(key), [ - (key, "ensure-task-exists", "released"), - (key, "released", "fetch", "fetch", {}), - ("gather-dependencies", w_from.address, lambda set_: key in set_), (key, "fetch", "flight", "flight", {}), ("request-dep", w_from.address, lambda set_: key in set_), ("receive-dep", w_from.address, lambda set_: key in set_), (key, "put-in-memory"), (key, "flight", "memory", "memory", {}), ], - # There may be additional ('missing', 'fetch', 'fetch') events if transfers - # are slow enough that the Active Memory Manager ends up requesting them a - # second time. Here we're asserting that no matter how slow CI is, all - # transfers will be completed within 2 seconds (hardcoded interval in - # Scheduler.retire_worker when AMM is not enabled). - strict=True, + strict=False, ) assert key in w_to.data # The key may or may not still be in w_from.data, depending if the AMM had the diff --git a/distributed/worker.py b/distributed/worker.py index e5563ed09d..2ad9132581 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -30,7 +30,7 @@ from pickle import PicklingError from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast -from tlz import first, keymap, merge, pluck # noqa: F401 +from tlz import first, keymap, merge, peekn, pluck # noqa: F401 from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -114,6 +114,8 @@ Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + GatherDep, + GatherDepDoneEvent, Instructions, InvalidTransition, LongRunningMsg, @@ -1201,7 +1203,7 @@ async def heartbeat(self): async def handle_scheduler(self, comm): try: - await self.handle_stream(comm, every_cycle=[self.ensure_communicating]) + await self.handle_stream(comm) except Exception as e: logger.exception(e) raise @@ -1937,6 +1939,12 @@ def handle_compute_task( for key, value in nbytes.items(): self.tasks[key].nbytes = value + def _add_to_data_needed(self, ts: TaskState) -> RecsInstrs: + self.data_needed.push(ts) + for w in ts.who_has: + self.data_needed_per_worker[w].push(ts) + return self._ensure_communicating() + def transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: @@ -1947,10 +1955,7 @@ def transition_missing_fetch( self._missing_dep_flight.discard(ts) ts.state = "fetch" ts.done = False - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return {}, [] + return self._add_to_data_needed(ts) def transition_missing_released( self, ts: TaskState, *, stimulus_id: str @@ -1987,10 +1992,7 @@ def transition_released_fetch( assert ts.priority is not None ts.state = "fetch" ts.done = False - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return {}, [] + return self._add_to_data_needed(ts) def transition_generic_released( self, ts: TaskState, *, stimulus_id: str @@ -2426,17 +2428,13 @@ def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsIns if not ts.done: return {}, [] - recommendations: Recs = {} ts.state = "fetch" ts.coming_from = None ts.done = False - if not ts.who_has: - recommendations[ts] = "missing" + if ts.who_has: + return self._add_to_data_needed(ts) else: - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return recommendations, [] + return {ts: "missing"}, [] def transition_flight_error( self, @@ -2699,6 +2697,7 @@ def _handle_instructions(self, instructions: Instructions) -> None: # TODO this method is temporary. # See final design: https://github.com/dask/distributed/issues/5894 for inst in instructions: + task = None if isinstance(inst, SendMessageToScheduler): self.batched_stream.send(inst.to_dict()) elif isinstance(inst, Execute): @@ -2706,11 +2705,27 @@ def _handle_instructions(self, instructions: Instructions) -> None: self.execute(inst.key, stimulus_id=inst.stimulus_id), name=f"execute({inst.key})", ) - self._async_instructions.add(task) - task.add_done_callback(self._handle_stimulus_from_task) + elif isinstance(inst, GatherDep): + assert inst.to_gather + keys_str = ", ".join(peekn(27, inst.to_gather)[0]) + if len(keys_str) > 80: + keys_str = keys_str[:77] + "..." + task = asyncio.create_task( + self.gather_dep( + inst.worker, + inst.to_gather, + total_nbytes=inst.total_nbytes, + stimulus_id=inst.stimulus_id, + ), + name=f"gather_dep({inst.worker}, {{{keys_str}}})", + ) else: raise TypeError(inst) # pragma: nocover + if task is not None: + self._async_instructions.add(task) + task.add_done_callback(self._handle_stimulus_from_task) + def maybe_transition_long_running( self, ts: TaskState, *, compute_duration: float, stimulus_id: str ): @@ -2747,13 +2762,17 @@ def stimulus_story( keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] - def ensure_communicating(self) -> None: + def _ensure_communicating(self) -> RecsInstrs: if self.status != Status.running: - return + return {}, [] stimulus_id = f"ensure-communicating-{time()}" skipped_worker_in_flight_or_busy = [] + recommendations: Recs = {} + instructions: Instructions = [] + all_keys_to_gather: set[str] = set() + while self.data_needed and ( len(self.in_flight_workers) < self.total_out_connections or self.comm_nbytes < self.comm_threshold_bytes @@ -2768,7 +2787,7 @@ def ensure_communicating(self) -> None: ts = self.data_needed.pop() - if ts.state != "fetch": + if ts.state != "fetch" or ts.key in all_keys_to_gather: continue if self.validate: @@ -2788,7 +2807,10 @@ def ensure_communicating(self) -> None: local = [w for w in workers if get_address_host(w) == host] worker = random.choice(local or workers) - to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) + to_gather, total_nbytes = self._select_keys_for_gather( + worker, ts.key, all_keys_to_gather + ) + all_keys_to_gather |= to_gather self.log.append( ("gather-dependencies", worker, to_gather, stimulus_id, time()) @@ -2796,22 +2818,30 @@ def ensure_communicating(self) -> None: self.comm_nbytes += total_nbytes self.in_flight_workers[worker] = to_gather - recommendations: Recs = { - self.tasks[d]: ("flight", worker) for d in to_gather - } - self.transitions(recommendations, stimulus_id=stimulus_id) + for d_key in to_gather: + d_ts = self.tasks[d_key] + if self.validate: + assert d_ts.state == "fetch" + assert d_ts not in recommendations + recommendations[d_ts] = ("flight", worker) - self.loop.add_callback( - self.gather_dep, - worker=worker, - to_gather=to_gather, - total_nbytes=total_nbytes, - stimulus_id=stimulus_id, + # Note: given n tasks that must be fetched from the same worker, this method + # may generate anywhere between 1 and n GatherDep instructions, as multiple + # tasks may be clustered in the same instruction by _select_keys_for_gather + instructions.append( + GatherDep( + worker=worker, + to_gather=to_gather, + total_nbytes=total_nbytes, + stimulus_id=stimulus_id, + ) ) for ts in skipped_worker_in_flight_or_busy: self.data_needed.push(ts) + return recommendations, instructions + def _get_task_finished_msg( self, ts: TaskState, stimulus_id: str ) -> TaskFinishedMsg: @@ -2896,16 +2926,22 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: self.log.append((ts.key, "put-in-memory", stimulus_id, time())) return recommendations - def select_keys_for_gather(self, worker, dep): - assert isinstance(dep, str) + def _select_keys_for_gather( + self, worker: str, dep: str, all_keys_to_gather: Container[str] + ) -> tuple[set[str], int]: + """``_ensure_communicating`` decided to fetch a single task from a worker, + following priority. In order to minimise overhead, request fetching other tasks + from the same worker within the message, following priority for the single + worker but ignoring higher priority tasks from other workers, up to + ``target_message_size``. + """ deps = {dep} - total_bytes = self.tasks[dep].get_nbytes() tasks = self.data_needed_per_worker[worker] while tasks: ts = tasks.peek() - if ts.state != "fetch": + if ts.state != "fetch" or ts.key in all_keys_to_gather: tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3031,7 +3067,7 @@ async def gather_dep( total_nbytes: int, *, stimulus_id: str, - ) -> None: + ) -> StateMachineEvent | None: """Gather dependencies for a task from a worker who has them Parameters @@ -3046,7 +3082,7 @@ async def gather_dep( Total number of bytes for all the dependencies in to_gather combined """ if self.status not in Status.ANY_RUNNING: # type: ignore - return + return None recommendations: Recs = {} response = {} @@ -3061,7 +3097,7 @@ async def gather_dep( self.log.append( ("nothing-to-gather", worker, to_gather, stimulus_id, time()) ) - return + return GatherDepDoneEvent(stimulus_id=stimulus_id) assert cause # Keep namespace clean since this func is long and has many @@ -3084,7 +3120,7 @@ async def gather_dep( ) stop = time() if response["status"] == "busy": - return + return GatherDepDoneEvent(stimulus_id=stimulus_id) self._update_metrics_received_data( start=start, @@ -3096,6 +3132,7 @@ async def gather_dep( self.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) + return GatherDepDoneEvent(stimulus_id=stimulus_id) except OSError: logger.exception("Worker stream died during communication: %s", worker) @@ -3112,6 +3149,8 @@ async def gather_dep( self.log.append( ("missing-who-has", worker, ts.key, stimulus_id, time()) ) + return GatherDepDoneEvent(stimulus_id=stimulus_id) + except Exception as e: logger.exception(e) if self.batched_stream and LOG_PDB: @@ -3122,7 +3161,8 @@ async def gather_dep( for k in self.in_flight_workers[worker]: ts = self.tasks[k] recommendations[ts] = tuple(msg.values()) - raise + return GatherDepDoneEvent(stimulus_id=stimulus_id) + finally: self.comm_nbytes -= total_nbytes busy = response.get("status", "") == "busy" @@ -3180,12 +3220,12 @@ async def gather_dep( ) self.update_who_has(who_has) - self.ensure_communicating() - @log_errors def _readd_busy_worker(self, worker: str) -> None: self.busy_workers.remove(worker) - self.ensure_communicating() + self.handle_stimulus( + GatherDepDoneEvent(stimulus_id=f"readd-busy-worker-{time()}") + ) @log_errors async def find_missing(self) -> None: @@ -3214,7 +3254,6 @@ async def find_missing(self) -> None: self.periodic_callbacks[ "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - self.ensure_communicating() def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: try: @@ -3686,8 +3725,15 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs: Worker.status back to running. """ assert self.status == Status.running - self.ensure_communicating() - return self._ensure_computing() + return merge_recs_instructions( + self._ensure_computing(), + self._ensure_communicating(), + ) + + @handle_event.register + def _(self, ev: GatherDepDoneEvent) -> RecsInstrs: + """Temporary hack - to be removed""" + return self._ensure_communicating() @handle_event.register def _(self, ev: CancelComputeEvent) -> RecsInstrs: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index abdbc12108..51c20657c6 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -259,18 +259,13 @@ class Instruction: __slots__ = () -# TODO https://github.com/dask/distributed/issues/5736 - -# @dataclass -# class GatherDep(Instruction): -# __slots__ = ("worker", "to_gather") -# worker: str -# to_gather: set[str] - - -# @dataclass -# class FindMissing(Instruction): -# __slots__ = () +@dataclass +class GatherDep(Instruction): + worker: str + to_gather: set[str] + total_nbytes: int + stimulus_id: str + __slots__ = tuple(__annotations__) # type: ignore @dataclass @@ -434,6 +429,13 @@ class UnpauseEvent(StateMachineEvent): __slots__ = () +@dataclass +class GatherDepDoneEvent(StateMachineEvent): + """Temporary hack - to be removed""" + + __slots__ = () + + @dataclass class ExecuteSuccessEvent(StateMachineEvent): key: str