Skip to content

Commit

Permalink
Finalise DumpInspector class
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Mar 10, 2022
1 parent 6063d68 commit 62c8810
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 63 deletions.
111 changes: 75 additions & 36 deletions distributed/cluster_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import msgpack

from distributed.compatibility import to_thread
from distributed.stories import scheduler_story, worker_story


def _tuple_to_list(node):
Expand Down Expand Up @@ -75,44 +76,82 @@ def load_cluster_dump(url: str):
return reader(f)


class ClusterInspector:
def __init__(
self, url_or_state: str | dict, context: Literal["scheduler" | "workers"]
):
class DumpInspector:
"""
Utility class for inspecting the state of a cluster dump
.. code-block:: python
inspector = DumpInspect("dump.msgpack.gz")
memory_tasks = inspector.tasks_in_state("memory")
released_tasks = inspector.tasks_in_state("released)
"""

def __init__(self, url_or_state: str | dict):
if isinstance(url_or_state, str):
self.dump = load_cluster_dump(url_or_state)
elif isinstance(url_or_state, dict):
self.dump = url_or_state
else:
raise TypeError(f"'url_or_state' must be a str or dict")

self.context = context


def get_tasks_in_state(
url: str,
state: str,
worker: bool = False,
) -> dict:
dump = load_cluster_dump(url)
context_str = "workers" if worker else "scheduler"

try:
context = dump[context_str]
except KeyError:
raise ValueError(
f"The '{context_str}' context was not present in the dumped state"
)

try:
tasks = context["tasks"]
except KeyError:
raise ValueError(
f"'tasks' was not present within the '{context_str}' "
f"context of the dumped state"
)

if state:
return {k: v for k, v in tasks.items() if v["state"] == state}

return tasks
raise TypeError("'url_or_state' must be a str or dict")

def tasks_in_state(self, state: str = "", workers: bool = False) -> dict:
"""
Returns
-------
tasks : dict
A dictionary of scheduler tasks with state `state`.
worker tasks are included if `workers=True`
"""
stasks = self.dump["scheduler"]["tasks"]

if state:
tasks = {k: v for k, v in stasks.items() if v["state"] == state}
else:
tasks = stasks.copy()

if not workers:
return tasks

for worker_dump in self.dump["workers"].values():
if state:
tasks.update(
(k, v)
for k, v in worker_dump["tasks"].items()
if v["state"] == state
)
else:
tasks.update(worker_dump["tasks"])

return tasks

def story(self, *key_or_stimulus_id: str, workers: bool = False) -> list:
"""
Returns
-------
stories : list
A list of stories for the keys/stimulus ID's in `*key_or_stimulus_id`.
worker stories are included if `workers=True`
"""
keys = set(key_or_stimulus_id)
story = scheduler_story(keys, self.dump["scheduler"]["transition_log"])

if not workers:
return story

for wdump in self.dump["workers"].values():
story.extend(worker_story(keys, wdump["log"]))

return story

def missing_workers(self) -> list:
"""
Returns
-------
missing : list
A list of workers connected to the scheduler, but which
did not respond to requests for a state dump.
"""
scheduler_workers = self.dump["scheduler"]["workers"]
responsive_workers = set(self.dump["workers"].keys())
return [w for w in scheduler_workers.keys() if w not in responsive_workers]
4 changes: 1 addition & 3 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
from .security import Security
from .semaphore import SemaphoreExtension
from .stealing import WorkStealing
from .stories import scheduler_story
from .utils import (
All,
TimeoutError,
Expand Down Expand Up @@ -7533,9 +7534,6 @@ def transitions(self, recommendations: dict):
def story(self, *keys):
"""Get all transitions that touch one of the input keys"""
keys = {key.key if isinstance(key, TaskState) else key for key in keys}

from .stories import scheduler_story

return scheduler_story(keys, self.transition_log)

transition_story = story
Expand Down
7 changes: 5 additions & 2 deletions distributed/stories.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
def scheduler_story(keys: set, transition_log: list):
from typing import Iterable


def scheduler_story(keys: set, transition_log: Iterable):
return [t for t in transition_log if t[0] in keys or keys.intersection(t[3])]


def worker_story(keys: set, log: list):
def worker_story(keys: set, log: Iterable):
return [
msg
for msg in log
Expand Down
32 changes: 12 additions & 20 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
tokenize,
wait,
)
from distributed.cluster_dump import get_tasks_in_state
from distributed.cluster_dump import DumpInspector, load_cluster_dump
from distributed.comm import CommClosedError
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import Status
Expand Down Expand Up @@ -7263,22 +7263,9 @@ def test_print_simple(capsys):


def _verify_cluster_dump(url, format: str, addresses: set[str]) -> dict:
fsspec = pytest.importorskip("fsspec")

url = str(url)
if format == "msgpack":
import msgpack

url += ".msgpack.gz"
loader = msgpack.unpack
else:
import yaml

url += ".yaml"
loader = yaml.safe_load

with fsspec.open(url, mode="rb", compression="infer") as f:
state = loader(f)
fsspec = pytest.importorskip("fsspec") # for load_cluster_dump
url = str(url) + (".msgpack.gz" if format == "msgpack" else ".yaml")
state = load_cluster_dump(url)

assert isinstance(state, dict)
assert "scheduler" in state
Expand Down Expand Up @@ -7349,8 +7336,9 @@ async def test_dump_cluster_state_json(c, s, a, b, tmp_path, local):

@pytest.mark.parametrize("local", [True, False])
@pytest.mark.parametrize("_format", ["msgpack", "yaml"])
@pytest.mark.parametrize("workers", [True, False])
@gen_cluster(client=True)
async def test_get_cluster_state(c, s, a, b, tmp_path, _format, local):
async def test_inspect_cluster_dump(c, s, a, b, tmp_path, _format, local, workers):
filename = tmp_path / "foo"
if not local:
pytest.importorskip("fsspec")
Expand All @@ -7362,9 +7350,13 @@ async def test_get_cluster_state(c, s, a, b, tmp_path, _format, local):
await c.dump_cluster_state(filename, format=_format)

suffix = ".gz" if _format == "msgpack" else ""
outfile = f"{filename}.{_format}{suffix}"
tasks = get_tasks_in_state(outfile, "memory")
inspector = DumpInspector(f"{filename}.{_format}{suffix}")
tasks = inspector.tasks_in_state("memory", workers=workers)
assert set(tasks.keys()) == set(map(str, A.__dask_keys__()))
it = iter(tasks.keys())
stories = inspector.story(next(it), workers=workers)
stories = inspector.story(next(it), next(it), workers=workers)
missing = inspector.missing_workers()


@gen_cluster(client=True)
Expand Down
3 changes: 1 addition & 2 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from .security import Security
from .shuffle import ShuffleWorkerExtension
from .sizeof import safe_sizeof as sizeof
from .stories import worker_story
from .threadpoolexecutor import ThreadPoolExecutor
from .threadpoolexecutor import secede as tpe_secede
from .utils import (
Expand Down Expand Up @@ -2895,8 +2896,6 @@ def stateof(self, key: str) -> dict[str, Any]:

def story(self, *keys_or_tasks: str | TaskState) -> list[tuple]:
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
from .stories import worker_story

return worker_story(keys, self.log)

def ensure_communicating(self) -> None:
Expand Down

0 comments on commit 62c8810

Please sign in to comment.