diff --git a/distributed/_stories.py b/distributed/_stories.py index d17e54df53f..40bb26fa804 100644 --- a/distributed/_stories.py +++ b/distributed/_stories.py @@ -1,44 +1,55 @@ +from __future__ import annotations + from typing import Iterable -def scheduler_story(keys: set, transition_log: Iterable) -> list: +def scheduler_story( + keys_or_stimuli: set[str], transition_log: Iterable[tuple] +) -> list[tuple]: """Creates a story from the scheduler transition log given a set of keys describing tasks or stimuli. Parameters ---------- - keys : set - A set of task `keys` or `stimulus_id`'s + keys_or_stimuli : set[str] + Task keys or stimulus_id's log : iterable The scheduler transition log Returns ------- - story : list + story : list[tuple] """ - return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])] + return [ + t + for t in transition_log + if t[0] in keys_or_stimuli or keys_or_stimuli.intersection(t[3]) + ] -def worker_story(keys: set, log: Iterable) -> list: +def worker_story(keys_or_tags: set[str], log: Iterable[tuple]) -> list: """Creates a story from the worker log given a set of keys describing tasks or stimuli. Parameters ---------- - keys : set - A set of task `keys` or `stimulus_id`'s + keys_or_tags : set[str] + Task keys or arbitrary tags from the transition log, e.g. stimulus_id's log : iterable The worker log Returns ------- - story : list + story : list[str] """ return [ msg for msg in log - if any(key in msg for key in keys) + if any(key in msg for key in keys_or_tags) or any( - key in c for key in keys for c in msg if isinstance(c, (tuple, list, set)) + key in c + for key in keys_or_tags + for c in msg + if isinstance(c, (tuple, list, set)) ) ] diff --git a/distributed/client.py b/distributed/client.py index af52607c9df..4d06a92e11c 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4286,11 +4286,11 @@ def collections_to_dsk(collections, *args, **kwargs): """Convert many collections into a single dask graph, after optimization""" return collections_to_dsk(collections, *args, **kwargs) - async def _story(self, keys=(), on_error="raise"): + async def _story(self, *keys_or_tags: str, on_error="raise"): assert on_error in ("raise", "ignore") try: - flat_stories = await self.scheduler.get_story(keys=keys) + flat_stories = await self.scheduler.get_story(keys_or_stimuli=keys_or_tags) flat_stories = [("scheduler", *msg) for msg in flat_stories] except Exception: if on_error == "raise": @@ -4301,15 +4301,17 @@ async def _story(self, keys=(), on_error="raise"): raise ValueError(f"on_error not in {'raise', 'ignore'}") responses = await self.scheduler.broadcast( - msg={"op": "get_story", "keys": keys}, on_error=on_error + msg={"op": "get_story", "keys_or_tags": keys_or_tags}, on_error=on_error ) for worker, stories in responses.items(): flat_stories.extend((worker, *msg) for msg in stories) return flat_stories - def story(self, *keys_or_stimulus_ids, on_error="raise"): - """Returns a cluster-wide story for the given keys or simtulus_id's""" - return self.sync(self._story, keys=keys_or_stimulus_ids, on_error=on_error) + def story(self, *keys_or_tags, on_error="raise"): + """Returns a cluster-wide story for the given keys or transition log tags, such + as stimulus_id's + """ + return self.sync(self._story, *keys_or_tags, on_error=on_error) def get_task_stream( self, diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 048661ac2f2..1d88b7db6b8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -6540,13 +6540,16 @@ def transitions(self, recommendations: dict, stimulus_id: str): self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id) self.send_all(client_msgs, worker_msgs) - def story(self, *keys): - """Get all transitions that touch one of the input keys""" - keys = {key.key if isinstance(key, TaskState) else key for key in keys} - return scheduler_story(keys, self.transition_log) + def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[tuple]: + """Get all transitions that touch one of the input keys or stimulus_id's""" + keys_or_stimuli = { + key.key if isinstance(key, TaskState) else key + for key in keys_or_tasks_or_stimuli + } + return scheduler_story(keys_or_stimuli, self.transition_log) - async def get_story(self, keys=()): - return self.story(*keys) + async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]: + return self.story(*keys_or_stimuli) transition_story = story diff --git a/distributed/worker.py b/distributed/worker.py index 0ad15cd3fb0..403dc19256b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2949,13 +2949,17 @@ def stateof(self, key: str) -> dict[str, Any]: "data": key in self.data, } - def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]: - """Return all transitions involving one or more tasks""" - keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} - return worker_story(keys, self.log) + def story(self, *keys_or_tasks_or_tags: str | TaskState) -> list[tuple]: + """Return all records from the transitions log involving one or more tasks; + it can also be used for arbitrary non-transition tags. + """ + keys_or_tags = { + e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_tags + } + return worker_story(keys_or_tags, self.log) - async def get_story(self, keys=None): - return self.story(*keys) + async def get_story(self, keys_or_tags: Iterable[str]) -> list[tuple]: + return self.story(*keys_or_tags) def stimulus_story( self, *keys_or_tasks: str | TaskState