diff --git a/distributed/_stories.py b/distributed/_stories.py index d17e54df53f..bc3624b900b 100644 --- a/distributed/_stories.py +++ b/distributed/_stories.py @@ -1,44 +1,55 @@ +from __future__ import annotations + from typing import Iterable -def scheduler_story(keys: set, transition_log: Iterable) -> list: +def scheduler_story( + keys_or_stimuli: set[str], transition_log: Iterable[tuple] +) -> list[tuple]: """Creates a story from the scheduler transition log given a set of keys describing tasks or stimuli. Parameters ---------- - keys : set - A set of task `keys` or `stimulus_id`'s + keys_or_stimuli : set[str] + Task keys or stimulus_id's log : iterable The scheduler transition log Returns ------- - story : list + story : list[tuple] """ - return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])] + return [ + t + for t in transition_log + if t[0] in keys_or_stimuli or keys_or_stimuli.intersection(t[3]) + ] -def worker_story(keys: set, log: Iterable) -> list: +def worker_story(keys_or_stimuli: set[str], log: Iterable[tuple]) -> list: """Creates a story from the worker log given a set of keys describing tasks or stimuli. Parameters ---------- - keys : set - A set of task `keys` or `stimulus_id`'s + keys_or_stimuli : set[str] + Task keys or stimulus_id's log : iterable The worker log Returns ------- - story : list + story : list[str] """ return [ msg for msg in log - if any(key in msg for key in keys) + if any(key in msg for key in keys_or_stimuli) or any( - key in c for key in keys for c in msg if isinstance(c, (tuple, list, set)) + key in c + for key in keys_or_stimuli + for c in msg + if isinstance(c, (tuple, list, set)) ) ] diff --git a/distributed/client.py b/distributed/client.py index 8361f3e5374..e62d3449df9 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4272,11 +4272,13 @@ def collections_to_dsk(collections, *args, **kwargs): """Convert many collections into a single dask graph, after optimization""" return collections_to_dsk(collections, *args, **kwargs) - async def _story(self, keys=(), on_error="raise"): + async def _story(self, *keys_or_stimuli: str, on_error="raise"): assert on_error in ("raise", "ignore") try: - flat_stories = await self.scheduler.get_story(keys=keys) + flat_stories = await self.scheduler.get_story( + keys_or_stimuli=keys_or_stimuli + ) flat_stories = [("scheduler", *msg) for msg in flat_stories] except Exception: if on_error == "raise": @@ -4287,15 +4289,16 @@ async def _story(self, keys=(), on_error="raise"): raise ValueError(f"on_error not in {'raise', 'ignore'}") responses = await self.scheduler.broadcast( - msg={"op": "get_story", "keys": keys}, on_error=on_error + msg={"op": "get_story", "keys_or_stimuli": keys_or_stimuli}, + on_error=on_error, ) for worker, stories in responses.items(): flat_stories.extend((worker, *msg) for msg in stories) return flat_stories - def story(self, *keys_or_stimulus_ids, on_error="raise"): - """Returns a cluster-wide story for the given keys or simtulus_id's""" - return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error) + def story(self, *keys_or_stimuli, on_error="raise"): + """Returns a cluster-wide story for the given keys or stimulus_id's""" + return self.sync(self._story, *keys_or_stimuli, on_error=on_error) def get_task_stream( self, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 26c5146d388..d5dccbc473e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6539,13 +6539,16 @@ def transitions(self, recommendations: dict, stimulus_id: str): self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def story(self, *keys): - """Get all transitions that touch one of the input keys""" - keys = {key.key if isinstance(key, TaskState) else key for key in keys} - return scheduler_story(keys, self.transition_log) + def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]: + """Get all transitions that touch one of the input keys or stimulus_id's""" + keys_or_stimuli = { + key.key if isinstance(key, TaskState) else key + for key in keys_or_tasks_or_stimuli + } + return scheduler_story(keys_or_stimuli, self.transition_log) - async def get_story(self, keys=()): - return self.story(*keys) + async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]: + return self.story(*keys_or_stimuli) transition_story = story diff --git a/distributed/worker.py b/distributed/worker.py index d39decd416d..7216af7aecf 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1843,8 +1843,8 @@ def stateof(self, key: str) -> dict[str, Any]: "data": key in self.data, } - async def get_story(self, keys=None): - return self.story(*keys) + async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]: + return self.state.story(*keys_or_stimuli) ########################## # Dependencies gathering # diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 3cca984cc7b..7c7bf686de1 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -2774,10 +2774,14 @@ def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs: # Diagnostics # ############### - def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: - """Return all transitions involving one or more tasks""" - keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} - return worker_story(keys, self.log) + def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]: + """Return all records from the transitions log involving one or more tasks or + stimulus_id's + """ + keys_or_stimuli = { + e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_stimuli + } + return worker_story(keys_or_stimuli, self.log) def stimulus_story( self, *keys_or_tasks: str | TaskState