diff --git a/distributed/core.py b/distributed/core.py index 9641911691..d54592ef5a 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -623,7 +623,7 @@ async def handle_comm(self, comm): "Failed while closing connection to %r: %s", address, e ) - async def handle_stream(self, comm, extra=None, every_cycle=()): + async def handle_stream(self, comm, extra=None): extra = extra or {} logger.info("Starting established connection") @@ -653,12 +653,6 @@ async def handle_stream(self, comm, extra=None, every_cycle=()): logger.error("odd message %s", msg) await asyncio.sleep(0) - for func in every_cycle: - if is_coroutine_function(func): - self.loop.add_callback(func) - else: - func() - except OSError: pass except Exception as e: diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py index 1616422f25..2a89d92222 100644 --- a/distributed/tests/test_stories.py +++ b/distributed/tests/test_stories.py @@ -134,14 +134,11 @@ async def test_worker_story_with_deps(c, s, a, b): story = a.story("res") assert story == [] - story = b.story("res") # Story now includes randomized stimulus_ids and timestamps. - stimulus_ids = {ev[-2] for ev in story} - # Compute dep - # Success dep - # Compute res - assert len(stimulus_ids) == 3 + story = b.story("res") + stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story} + assert stimulus_ids == {"compute-task", "task-finished"} # This is a simple transition log expected = [ @@ -155,8 +152,8 @@ async def test_worker_story_with_deps(c, s, a, b): assert_story(story, expected, strict=True) story = b.story("dep") - stimulus_ids = {ev[-2] for ev in story} - assert len(stimulus_ids) == 2, stimulus_ids + stimulus_ids = {ev[-2].rsplit("-", 1)[0] for ev in story} + assert stimulus_ids == {"compute-task"} expected = [ ("dep", "ensure-task-exists", "released"), ("dep", "released", "fetch", "fetch", {}), diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index adc0c16bd0..27375c7c71 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -631,11 +631,13 @@ async def test_clean(c, s, a, b): @gen_cluster(client=True) async def test_message_breakup(c, s, a, b): - n = 100000 + n = 100_000 a.target_message_size = 10 * n b.target_message_size = 10 * n - xs = [c.submit(mul, b"%d" % i, n, workers=a.address) for i in range(30)] - y = c.submit(lambda *args: None, xs, workers=b.address) + xs = [ + c.submit(mul, b"%d" % i, n, key=f"x{i}", workers=[a.address]) for i in range(30) + ] + y = c.submit(lambda _: None, xs, key="y", workers=[b.address]) await y assert 2 <= len(b.incoming_transfer_log) <= 20 @@ -714,27 +716,32 @@ async def test_clean_nbytes(c, s, a, b): ) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 20) -async def test_gather_many_small(c, s, a, *workers): +@pytest.mark.parametrize("as_deps", [True, False]) +@gen_cluster(client=True, nthreads=[("", 1)] * 20) +async def test_gather_many_small(c, s, a, *workers, as_deps): """If the dependencies of a given task are very small, do not limit the number of concurrent outgoing connections """ a.total_out_connections = 2 - futures = await c._scatter(list(range(100))) - + futures = await c.scatter( + {f"x{i}": i for i in range(100)}, + workers=[w.address for w in workers], + ) assert all(w.data for w in workers) - def f(*args): - return 10 - - future = c.submit(f, *futures, workers=a.address) - await wait(future) + if as_deps: + future = c.submit(lambda _: None, futures, key="y", workers=[a.address]) + await wait(future) + else: + 
s.request_acquire_replicas(a.address, list(futures), stimulus_id="test") + while len(a.data) < 100: + await asyncio.sleep(0.01) types = list(pluck(0, a.log)) req = [i for i, t in enumerate(types) if t == "request-dep"] recv = [i for i, t in enumerate(types) if t == "receive-dep"] + assert len(req) == len(recv) == 19 assert min(recv) > max(req) - assert a.comm_nbytes == 0 @@ -1424,21 +1431,13 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None: assert_story( w_to.story(key), [ - (key, "ensure-task-exists", "released"), - (key, "released", "fetch", "fetch", {}), - ("gather-dependencies", w_from.address, lambda set_: key in set_), (key, "fetch", "flight", "flight", {}), ("request-dep", w_from.address, lambda set_: key in set_), ("receive-dep", w_from.address, lambda set_: key in set_), (key, "put-in-memory"), (key, "flight", "memory", "memory", {}), ], - # There may be additional ('missing', 'fetch', 'fetch') events if transfers - # are slow enough that the Active Memory Manager ends up requesting them a - # second time. Here we're asserting that no matter how slow CI is, all - # transfers will be completed within 2 seconds (hardcoded interval in - # Scheduler.retire_worker when AMM is not enabled). - strict=True, + strict=False, ) assert key in w_to.data # The key may or may not still be in w_from.data, depending if the AMM had the @@ -3054,7 +3053,7 @@ async def test_missing_released_zombie_tasks_2(c, s, b): await asyncio.sleep(0) ts = b.tasks[f1.key] - assert ts.state == "fetch" + assert ts.state == "flight" while ts.state != "missing": # If we sleep for a longer time, the worker will spin into an diff --git a/distributed/worker.py b/distributed/worker.py index 601b053cad..674299e90b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -30,7 +30,7 @@ from pickle import PicklingError from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast -from tlz import first, keymap, merge, pluck # noqa: F401 +from tlz import first, keymap, merge, peekn, pluck # noqa: F401 from tornado.ioloop import IOLoop, PeriodicCallback import dask @@ -111,9 +111,12 @@ AddKeysMsg, AlreadyCancelledEvent, CancelComputeEvent, + EnsureCommunicatingAfterTransitions, Execute, ExecuteFailureEvent, ExecuteSuccessEvent, + GatherDep, + GatherDepDoneEvent, Instructions, InvalidTransition, LongRunningMsg, @@ -1275,7 +1278,7 @@ async def heartbeat(self): @fail_hard async def handle_scheduler(self, comm): - await self.handle_stream(comm, every_cycle=[self.ensure_communicating]) + await self.handle_stream(comm) if self.reconnect and self.status in WORKER_ANY_RUNNING: logger.info("Connection to scheduler broken. Reconnecting...") @@ -1989,6 +1992,10 @@ def handle_compute_task( ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) + if nbytes is not None: + for key, value in nbytes.items(): + self.tasks[key].nbytes = value + if ts.state in READY | {"executing", "waiting", "resumed"}: pass elif ts.state == "memory": @@ -2012,9 +2019,16 @@ def handle_compute_task( self.update_who_has(who_has) self.transitions(recommendations, stimulus_id=stimulus_id) - if nbytes is not None: - for key, value in nbytes.items(): - self.tasks[key].nbytes = value + def _add_to_data_needed(self, ts: TaskState, stimulus_id: str) -> RecsInstrs: + self.data_needed.push(ts) + for w in ts.who_has: + self.data_needed_per_worker[w].push(ts) + # This is the same as `return self._ensure_communicating()`, except that when + # many tasks transition to fetch at the same time, e.g. 
from a single + # compute-task or acquire-replicas command from the scheduler, it allows + # clustering the transfers into less GatherDep instructions; see + # _select_keys_for_gather(). + return {}, [EnsureCommunicatingAfterTransitions(stimulus_id=stimulus_id)] def transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str @@ -2027,10 +2041,7 @@ def transition_missing_fetch( self._missing_dep_flight.discard(ts) ts.state = "fetch" ts.done = False - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return {}, [] + return self._add_to_data_needed(ts, stimulus_id=stimulus_id) def transition_missing_released( self, ts: TaskState, *, stimulus_id: str @@ -2069,10 +2080,7 @@ def transition_released_fetch( return {ts: "missing"}, [] ts.state = "fetch" ts.done = False - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return {}, [] + return self._add_to_data_needed(ts, stimulus_id=stimulus_id) def transition_generic_released( self, ts: TaskState, *, stimulus_id: str @@ -2523,17 +2531,13 @@ def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsIns if not ts.done: return {}, [] - recommendations: Recs = {} ts.state = "fetch" ts.coming_from = None ts.done = False - if not ts.who_has: - recommendations[ts] = "missing" + if ts.who_has: + return self._add_to_data_needed(ts, stimulus_id=stimulus_id) else: - self.data_needed.push(ts) - for w in ts.who_has: - self.data_needed_per_worker[w].push(ts) - return recommendations, [] + return {ts: "missing"}, [] def transition_flight_error( self, @@ -2804,20 +2808,63 @@ def _handle_stimulus_from_task( @fail_hard def _handle_instructions(self, instructions: Instructions) -> None: - # TODO this method is temporary. - # See final design: https://github.com/dask/distributed/issues/5894 - for inst in instructions: - if isinstance(inst, SendMessageToScheduler): - self.batched_stream.send(inst.to_dict()) - elif isinstance(inst, Execute): - task = asyncio.create_task( - self.execute(inst.key, stimulus_id=inst.stimulus_id), - name=f"execute({inst.key})", + while instructions: + ensure_communicating: EnsureCommunicatingAfterTransitions | None = None + for inst in instructions: + task: asyncio.Task | None = None + + if isinstance(inst, SendMessageToScheduler): + self.batched_stream.send(inst.to_dict()) + + elif isinstance(inst, EnsureCommunicatingAfterTransitions): + # A single compute-task or acquire-replicas command may cause + # multiple tasks to transition to fetch; this in turn means that we + # will receive multiple instances of this instruction. + # _ensure_communicating is a no-op if it runs twice in a row; we're + # not calling it inside the for loop to avoid a O(n^2) condition + # when + # 1. there are many fetches queued because all workers are in flight + # 2. a single compute-task or acquire-replicas command just sent + # many dependencies to fetch at once. + ensure_communicating = inst + + elif isinstance(inst, GatherDep): + assert inst.to_gather + keys_str = ", ".join(peekn(27, inst.to_gather)[0]) + if len(keys_str) > 80: + keys_str = keys_str[:77] + "..." 
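+                    # keys_str is only used to build a readable task name below;
+                    # it has no effect on which keys are actually gathered.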
+ task = asyncio.create_task( + self.gather_dep( + inst.worker, + inst.to_gather, + total_nbytes=inst.total_nbytes, + stimulus_id=inst.stimulus_id, + ), + name=f"gather_dep({inst.worker}, {{{keys_str}}})", + ) + + elif isinstance(inst, Execute): + task = asyncio.create_task( + self.execute(inst.key, stimulus_id=inst.stimulus_id), + name=f"execute({inst.key})", + ) + + else: + raise TypeError(inst) # pragma: nocover + + if task is not None: + self._async_instructions.add(task) + task.add_done_callback(self._handle_stimulus_from_task) + + if ensure_communicating: + # Potentially re-fill instructions, causing a second iteration of `while + # instructions` at the top of this method + recs, instructions = self._ensure_communicating( + stimulus_id=ensure_communicating.stimulus_id ) - self._async_instructions.add(task) - task.add_done_callback(self._handle_stimulus_from_task) + self.transitions(recs, stimulus_id=ensure_communicating.stimulus_id) else: - raise TypeError(inst) # pragma: nocover + instructions = [] def maybe_transition_long_running( self, ts: TaskState, *, compute_duration: float, stimulus_id: str @@ -2855,13 +2902,16 @@ def stimulus_story( keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys] - def ensure_communicating(self) -> None: + def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: if self.status != Status.running: - return + return {}, [] - stimulus_id = f"ensure-communicating-{time()}" skipped_worker_in_flight_or_busy = [] + recommendations: Recs = {} + instructions: Instructions = [] + all_keys_to_gather: set[str] = set() + while self.data_needed and ( len(self.in_flight_workers) < self.total_out_connections or self.comm_nbytes < self.comm_threshold_bytes @@ -2876,7 +2926,7 @@ def ensure_communicating(self) -> None: ts = self.data_needed.pop() - if ts.state != "fetch": + if ts.state != "fetch" or ts.key in all_keys_to_gather: continue if self.validate: @@ -2896,7 +2946,10 @@ def ensure_communicating(self) -> None: local = [w for w in workers if get_address_host(w) == host] worker = random.choice(local or workers) - to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key) + to_gather, total_nbytes = self._select_keys_for_gather( + worker, ts.key, all_keys_to_gather + ) + all_keys_to_gather |= to_gather self.log.append( ("gather-dependencies", worker, to_gather, stimulus_id, time()) @@ -2904,22 +2957,32 @@ def ensure_communicating(self) -> None: self.comm_nbytes += total_nbytes self.in_flight_workers[worker] = to_gather - recommendations: Recs = { - self.tasks[d]: ("flight", worker) for d in to_gather - } - self.transitions(recommendations, stimulus_id=stimulus_id) - - self.loop.add_callback( - self.gather_dep, - worker=worker, - to_gather=to_gather, - total_nbytes=total_nbytes, - stimulus_id=stimulus_id, + for d_key in to_gather: + d_ts = self.tasks[d_key] + if self.validate: + assert d_ts.state == "fetch" + assert d_ts not in recommendations + recommendations[d_ts] = ("flight", worker) + + # A single invocation of _ensure_communicating may generate up to one + # GatherDep instruction per worker. Multiple tasks from the same worker may + # be clustered in the same instruction by _select_keys_for_gather. But once + # a worker has been selected for a GatherDep and added to in_flight_workers, + # it won't be selected again until the gather completes. 
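+            # _handle_instructions turns each GatherDep instruction into an
+            # asyncio task that runs Worker.gather_dep.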
+ instructions.append( + GatherDep( + worker=worker, + to_gather=to_gather, + total_nbytes=total_nbytes, + stimulus_id=stimulus_id, + ) ) for ts in skipped_worker_in_flight_or_busy: self.data_needed.push(ts) + return recommendations, instructions + def _get_task_finished_msg( self, ts: TaskState, stimulus_id: str ) -> TaskFinishedMsg: @@ -3004,8 +3067,15 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs: self.log.append((ts.key, "put-in-memory", stimulus_id, time())) return recommendations - def select_keys_for_gather(self, worker, dep): - assert isinstance(dep, str) + def _select_keys_for_gather( + self, worker: str, dep: str, all_keys_to_gather: Container[str] + ) -> tuple[set[str], int]: + """``_ensure_communicating`` decided to fetch a single task from a worker, + following priority. In order to minimise overhead, request fetching other tasks + from the same worker within the message, following priority for the single + worker but ignoring higher priority tasks from other workers, up to + ``target_message_size``. + """ deps = {dep} total_bytes = self.tasks[dep].get_nbytes() @@ -3013,7 +3083,8 @@ def select_keys_for_gather(self, worker, dep): while tasks: ts = tasks.peek() - if ts.state != "fetch": + if ts.state != "fetch" or ts.key in all_keys_to_gather: + # Do not acquire the same key twice if multiple workers holds replicas tasks.pop() continue if total_bytes + ts.get_nbytes() > self.target_message_size: @@ -3140,7 +3211,7 @@ async def gather_dep( total_nbytes: int, *, stimulus_id: str, - ) -> None: + ) -> StateMachineEvent | None: """Gather dependencies for a task from a worker who has them Parameters @@ -3155,12 +3226,16 @@ async def gather_dep( Total number of bytes for all the dependencies in to_gather combined """ if self.status not in WORKER_ANY_RUNNING: # type: ignore - return + return None recommendations: Recs = {} response = {} to_gather_keys: set[str] = set() cancelled_keys: set[str] = set() + + def done_event(): + return GatherDepDoneEvent(stimulus_id=f"gather-dep-done-{time()}") + try: to_gather_keys, cancelled_keys, cause = self._filter_deps_for_fetch( to_gather @@ -3170,7 +3245,7 @@ async def gather_dep( self.log.append( ("nothing-to-gather", worker, to_gather, stimulus_id, time()) ) - return + return done_event() assert cause # Keep namespace clean since this func is long and has many @@ -3193,7 +3268,7 @@ async def gather_dep( ) stop = time() if response["status"] == "busy": - return + return done_event() self._update_metrics_received_data( start=start, @@ -3205,6 +3280,7 @@ async def gather_dep( self.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) + return done_event() except OSError: logger.exception("Worker stream died during communication: %s", worker) @@ -3226,6 +3302,8 @@ async def gather_dep( self.log.append( ("missing-who-has", worker, ts.key, stimulus_id, time()) ) + return done_event() + except Exception as e: logger.exception(e) if self.batched_stream and LOG_PDB: @@ -3236,7 +3314,8 @@ async def gather_dep( for k in self.in_flight_workers[worker]: ts = self.tasks[k] recommendations[ts] = tuple(msg.values()) - raise + return done_event() + finally: self.comm_nbytes -= total_nbytes busy = response.get("status", "") == "busy" @@ -3291,12 +3370,12 @@ async def gather_dep( ) self.update_who_has(who_has) - self.ensure_communicating() - @log_errors def _readd_busy_worker(self, worker: str) -> None: self.busy_workers.remove(worker) - self.ensure_communicating() + self.handle_stimulus( + 
GatherDepDoneEvent(stimulus_id=f"readd-busy-worker-{time()}") + ) @log_errors async def find_missing(self) -> None: @@ -3325,7 +3404,6 @@ async def find_missing(self) -> None: self.periodic_callbacks[ "find-missing" ].callback_time = self.periodic_callbacks["heartbeat"].callback_time - self.ensure_communicating() def update_who_has(self, who_has: dict[str, Collection[str]]) -> None: try: @@ -3798,8 +3876,15 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs: Worker.status back to running. """ assert self.status == Status.running - self.ensure_communicating() - return self._ensure_computing() + return merge_recs_instructions( + self._ensure_computing(), + self._ensure_communicating(stimulus_id=ev.stimulus_id), + ) + + @handle_event.register + def _(self, ev: GatherDepDoneEvent) -> RecsInstrs: + """Temporary hack - to be removed""" + return self._ensure_communicating(stimulus_id=ev.stimulus_id) @handle_event.register def _(self, ev: CancelComputeEvent) -> RecsInstrs: diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index abdbc12108..8619db8bf2 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -259,18 +259,13 @@ class Instruction: __slots__ = () -# TODO https://github.com/dask/distributed/issues/5736 - -# @dataclass -# class GatherDep(Instruction): -# __slots__ = ("worker", "to_gather") -# worker: str -# to_gather: set[str] - - -# @dataclass -# class FindMissing(Instruction): -# __slots__ = () +@dataclass +class GatherDep(Instruction): + worker: str + to_gather: set[str] + total_nbytes: int + stimulus_id: str + __slots__ = tuple(__annotations__) # type: ignore @dataclass @@ -292,6 +287,12 @@ def to_dict(self) -> dict[str, Any]: return d +@dataclass +class EnsureCommunicatingAfterTransitions(Instruction): + __slots__ = ("stimulus_id",) + stimulus_id: str + + @dataclass class TaskFinishedMsg(SendMessageToScheduler): op = "task-finished" @@ -434,6 +435,13 @@ class UnpauseEvent(StateMachineEvent): __slots__ = () +@dataclass +class GatherDepDoneEvent(StateMachineEvent): + """Temporary hack - to be removed""" + + __slots__ = () + + @dataclass class ExecuteSuccessEvent(StateMachineEvent): key: str
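
Below is a minimal, self-contained sketch (not code from this diff) of the batching idea behind `EnsureCommunicatingAfterTransitions`, `_ensure_communicating` and `_select_keys_for_gather`: many fetch transitions collapse into a single scheduling pass, which emits at most one `GatherDep` per worker and packs keys held by that worker up to a message-size budget. The names `cluster_fetches`, `GatherDepSketch` and the `(key, worker, nbytes)` queue format are illustrative only.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class GatherDepSketch:
    """Illustrative stand-in for the GatherDep instruction; not the real class."""

    worker: str
    to_gather: set[str]
    total_nbytes: int


def cluster_fetches(
    queue: list[tuple[str, str, int]],  # (key, worker, nbytes), highest priority first
    target_message_size: int = 50_000_000,
) -> list[GatherDepSketch]:
    """Emit at most one GatherDepSketch per worker, packing keys from the same
    worker into one message until target_message_size is reached. Keys that do
    not fit, or whose replica has already been scheduled from another worker,
    are simply skipped here; in the real worker they stay queued in data_needed
    for a later pass.
    """
    scheduled: set[str] = set()
    per_worker: dict[str, GatherDepSketch] = {}
    for key, worker, nbytes in queue:
        if key in scheduled:
            continue  # another replica of this key is already being fetched
        inst = per_worker.get(worker)
        if inst is None:
            per_worker[worker] = GatherDepSketch(worker, {key}, nbytes)
            scheduled.add(key)
        elif inst.total_nbytes + nbytes <= target_message_size:
            inst.to_gather.add(key)
            inst.total_nbytes += nbytes
            scheduled.add(key)
    return list(per_worker.values())


if __name__ == "__main__":
    queue = [
        ("x-0", "tcp://10.0.0.1:40000", 1_000),
        ("x-1", "tcp://10.0.0.1:40000", 1_000),
        ("y-0", "tcp://10.0.0.2:40000", 1_000),
    ]
    for inst in cluster_fetches(queue):
        print(inst)  # two instructions: one per worker, with x-0/x-1 clustered
```

The real `_ensure_communicating` additionally respects `total_out_connections`, `comm_threshold_bytes` and `in_flight_workers`, and hands the resulting `GatherDep` instructions to `_handle_instructions`, which spawns one `gather_dep` asyncio task per instruction.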