Deprecate cubed.extensions and move to cubed.diagnostics (#533)
* Rename `cubed.extensions` to `cubed.diagnostics`

* Deprecate `cubed.extensions`

* Update tests, examples and notebooks to use `cubed.diagnostics`
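
For downstream code, the update is a one-line import change. An illustrative sketch (HistoryCallback stands in for any of the moved classes):

    # before: the deprecated location (still importable for now)
    from cubed.extensions.history import HistoryCallback
    # after: the new location introduced by this commit
    from cubed.diagnostics.history import HistoryCallback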
tomwhite authored Aug 2, 2024
1 parent ceddb7f commit 19f844e
Showing 20 changed files with 463 additions and 420 deletions.
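
Only the four new cubed/diagnostics modules are reproduced below; the deprecation shims left behind in cubed/extensions are among the other changed files and are not shown. A minimal sketch of what such a shim typically looks like, assuming the common re-export-and-warn pattern (illustrative, not the actual committed code):

    # hypothetical cubed/extensions/history.py shim: re-export from the new
    # location and warn on import; the real shim in this commit may differ
    import warnings

    from cubed.diagnostics.history import HistoryCallback  # noqa: F401

    warnings.warn(
        "cubed.extensions.history is deprecated, use cubed.diagnostics.history",
        DeprecationWarning,
        stacklevel=2,
    )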
Empty file added cubed/diagnostics/__init__.py
102 changes: 102 additions & 0 deletions cubed/diagnostics/history.py
@@ -0,0 +1,102 @@
from dataclasses import asdict
from pathlib import Path

import pandas as pd

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class HistoryCallback(Callback):
    """Collects plan and task statistics for a computation and writes them
    to CSV files under history/<compute_id>/."""

    def on_compute_start(self, event):
        plan = []
        for name, node in visit_nodes(event.dag, event.resume):
            primitive_op = node["primitive_op"]
            plan.append(
                dict(
                    name=name,
                    op_name=node["op_name"],
                    projected_mem=primitive_op.projected_mem,
                    reserved_mem=primitive_op.reserved_mem,
                    num_tasks=primitive_op.num_tasks,
                )
            )

        self.plan = plan
        self.events = []

    def on_task_end(self, event):
        self.events.append(asdict(event))

    def on_compute_end(self, event):
        self.plan_df = pd.DataFrame(self.plan)
        self.events_df = pd.DataFrame(self.events)
        history_path = Path(f"history/{event.compute_id}")
        history_path.mkdir(parents=True, exist_ok=True)
        self.plan_df_path = history_path / "plan.csv"
        self.events_df_path = history_path / "events.csv"
        self.stats_df_path = history_path / "stats.csv"
        self.plan_df.to_csv(self.plan_df_path, index=False)
        self.events_df.to_csv(self.events_df_path, index=False)

        self.stats_df = analyze(self.plan_df, self.events_df)
        self.stats_df.to_csv(self.stats_df_path, index=False)


def analyze(plan_df, events_df):
    # convert memory to MB
    plan_df["projected_mem_mb"] = plan_df["projected_mem"] / 1_000_000
    plan_df["reserved_mem_mb"] = plan_df["reserved_mem"] / 1_000_000
    plan_df = plan_df[
        [
            "name",
            "op_name",
            "projected_mem_mb",
            "reserved_mem_mb",
            "num_tasks",
        ]
    ]
    events_df["peak_measured_mem_start_mb"] = (
        events_df["peak_measured_mem_start"] / 1_000_000
    )
    events_df["peak_measured_mem_end_mb"] = (
        events_df["peak_measured_mem_end"] / 1_000_000
    )
    events_df["peak_measured_mem_delta_mb"] = (
        events_df["peak_measured_mem_end_mb"] - events_df["peak_measured_mem_start_mb"]
    )

    # find per-array stats
    df = events_df.groupby("name", as_index=False).agg(
        {
            "peak_measured_mem_start_mb": ["min", "mean", "max"],
            "peak_measured_mem_end_mb": ["max"],
            "peak_measured_mem_delta_mb": ["min", "mean", "max"],
        }
    )

    # flatten multi-index
    df.columns = ["_".join(a).rstrip("_") for a in df.columns.to_flat_index()]
    df = df.merge(plan_df, on="name")

    def projected_mem_utilization(row):
        return row["peak_measured_mem_end_mb_max"] / row["projected_mem_mb"]

    df["projected_mem_utilization"] = df.apply(projected_mem_utilization, axis=1)
    df = df[
        [
            "name",
            "op_name",
            "num_tasks",
            "peak_measured_mem_start_mb_max",
            "peak_measured_mem_end_mb_max",
            "peak_measured_mem_delta_mb_max",
            "projected_mem_mb",
            "reserved_mem_mb",
            "projected_mem_utilization",
        ]
    ]

    return df
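
A minimal usage sketch for the callback above (not part of this diff), assuming the usual Cubed pattern of passing callbacks to compute; the spec values are hypothetical:

    import cubed
    import cubed.array_api as xp
    from cubed.diagnostics.history import HistoryCallback

    spec = cubed.Spec(work_dir="tmp", allowed_mem="1GB")  # hypothetical values
    a = xp.ones((5000, 5000), chunks=(500, 500), spec=spec)
    # plan.csv, events.csv and stats.csv are written under history/<compute_id>/
    (a + a).compute(callbacks=[HistoryCallback()])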
35 changes: 35 additions & 0 deletions cubed/diagnostics/mem_warn.py
@@ -0,0 +1,35 @@
import warnings
from collections import Counter

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class MemoryWarningCallback(Callback):
    """Warns at the end of a computation if any task's peak measured memory
    exceeded allowed_mem."""

    def on_compute_start(self, event):
        # store ops keyed by name
        self.ops = {}
        for name, node in visit_nodes(event.dag, event.resume):
            primitive_op = node["primitive_op"]
            self.ops[name] = primitive_op

        # count number of times each op exceeds allowed mem
        self.counter = Counter()

    def on_task_end(self, event):
        allowed_mem = self.ops[event.name].allowed_mem
        if (
            event.peak_measured_mem_end is not None
            and event.peak_measured_mem_end > allowed_mem
        ):
            self.counter.update({event.name: 1})

    def on_compute_end(self, event):
        if sum(self.counter.values()) > 0:
            exceeded = [
                f"{k} ({v}/{self.ops[k].num_tasks})" for k, v in self.counter.items()
            ]
            warnings.warn(
                f"Peak memory usage exceeded allowed_mem when running tasks: {', '.join(exceeded)}",
                UserWarning,
            )
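
Attached the same way as the other callbacks (illustrative sketch, reusing the array from the history example above):

    from cubed.diagnostics.mem_warn import MemoryWarningCallback

    # emits a single UserWarning at the end of the computation if any task's
    # measured peak memory exceeded allowed_mem, listing each offending op
    # with the count of tasks that exceeded it
    (a + a).compute(callbacks=[MemoryWarningCallback()])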
117 changes: 117 additions & 0 deletions cubed/diagnostics/rich.py
@@ -0,0 +1,117 @@
import logging
import sys
from contextlib import contextmanager

from rich.console import RenderableType
from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    SpinnerColumn,
    Task,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
)
from rich.text import Text

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class RichProgressBar(Callback):
    """Rich progress bar for a computation."""

    def on_compute_start(self, event):
        # Set the pulse_style to the background colour to disable pulsing,
        # since Rich will pulse all non-started bars.
        logger_aware_progress = LoggerAwareProgress(
            SpinnerWhenRunningColumn(),
            TextColumn("[progress.description]{task.description}"),
            LeftJustifiedMofNCompleteColumn(),
            BarColumn(bar_width=None, pulse_style="bar.back"),
            TaskProgressColumn(
                text_format="[progress.percentage]{task.percentage:>3.1f}%"
            ),
            TimeElapsedColumn(),
            logger=logging.getLogger(),
        )
        progress = logger_aware_progress.__enter__()

        progress_tasks = {}
        for name, node in visit_nodes(event.dag, event.resume):
            num_tasks = node["primitive_op"].num_tasks
            op_display_name = node["op_display_name"].replace("\n", " ")
            progress_task = progress.add_task(
                f"{op_display_name}", start=False, total=num_tasks
            )
            progress_tasks[name] = progress_task

        self.logger_aware_progress = logger_aware_progress
        self.progress = progress
        self.progress_tasks = progress_tasks

    def on_compute_end(self, event):
        self.logger_aware_progress.__exit__(None, None, None)

    def on_operation_start(self, event):
        self.progress.start_task(self.progress_tasks[event.name])

    def on_task_end(self, event):
        self.progress.update(
            self.progress_tasks[event.name], advance=event.num_tasks, refresh=True
        )


class SpinnerWhenRunningColumn(SpinnerColumn):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Override so spinner is not shown when bar has not yet started
    def render(self, task: "Task") -> RenderableType:
        text = (
            self.finished_text
            if not task.started or task.finished
            else self.spinner.render(task.get_time())
        )
        return text


class LeftJustifiedMofNCompleteColumn(MofNCompleteColumn):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def render(self, task: "Task") -> Text:
        """Show completed/total."""
        completed = int(task.completed)
        total = int(task.total) if task.total is not None else "?"
        total_width = len(str(total))
        return Text(
            f"{completed}{self.separator}{total}".ljust(total_width + 1 + total_width),
            style="progress.download",
        )


# Based on CustomProgress from https://github.com/Textualize/rich/discussions/1578
@contextmanager
def LoggerAwareProgress(*args, **kwargs):
    """Wrapper around rich.progress.Progress to manage logging output to stderr."""
    __logger = kwargs.pop("logger", None)
    streamhandlers = [
        x for x in __logger.root.handlers if type(x) is logging.StreamHandler
    ]
    # remember each handler's original stream so all can be restored on exit
    prior_streams = [handler.stream for handler in streamhandlers]
    try:
        with Progress(*args, **kwargs) as progress:
            for handler in streamhandlers:
                handler.setStream(sys.stderr)

            yield progress

    finally:
        for handler, stream in zip(streamhandlers, prior_streams):
            handler.setStream(stream)
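
Usage follows the same callback pattern (illustrative sketch):

    from cubed.diagnostics.rich import RichProgressBar

    # renders one bar per operation in the plan, started as each operation begins
    (a + a).compute(callbacks=[RichProgressBar()])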
101 changes: 101 additions & 0 deletions cubed/diagnostics/timeline.py
@@ -0,0 +1,101 @@
import os
import time
from dataclasses import asdict
from typing import Optional

import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import pylab
import seaborn as sns

from cubed.runtime.types import Callback

sns.set_style("whitegrid")
pylab.switch_backend("Agg")


class TimelineVisualizationCallback(Callback):
    def __init__(self, format: Optional[str] = None) -> None:
        self.format = format

    def on_compute_start(self, event):
        self.start_tstamp = time.time()
        self.stats = []

    def on_task_end(self, event):
        self.stats.append(asdict(event))

    def on_compute_end(self, event):
        end_tstamp = time.time()
        dst = f"history/{event.compute_id}"
        format = self.format
        create_timeline(self.stats, self.start_tstamp, end_tstamp, dst, format)


# copy of the lithops function of the same name, modified for different field names
def create_timeline(stats, start_tstamp, end_tstamp, dst=None, format=None):
    stats_df = pd.DataFrame(stats)

    stats_df = stats_df.sort_values(by=["task_create_tstamp", "name"], ascending=True)

    total_calls = len(stats_df)

    palette = sns.color_palette("deep", 6)

    fig = pylab.figure(figsize=(10, 6))
    ax = fig.add_subplot(1, 1, 1)

    y = np.arange(total_calls)
    point_size = 10

    fields = [
        ("task create", stats_df.task_create_tstamp - start_tstamp),
        ("function start", stats_df.function_start_tstamp - start_tstamp),
        ("function end", stats_df.function_end_tstamp - start_tstamp),
        ("task result", stats_df.task_result_tstamp - start_tstamp),
    ]

    patches = []
    for f_i, (field_name, val) in enumerate(fields):
        ax.scatter(val, y, c=[palette[f_i]], edgecolor="none", s=point_size, alpha=0.8)
        patches.append(mpatches.Patch(color=palette[f_i], label=field_name))

    ax.set_xlabel("Execution Time (sec)")
    ax.set_ylabel("Function Call")

    legend = pylab.legend(handles=patches, loc="upper right", frameon=True)
    legend.get_frame().set_facecolor("#FFFFFF")

    yplot_step = int(np.max([1, total_calls / 20]))
    y_ticks = np.arange(total_calls // yplot_step + 2) * yplot_step
    ax.set_yticks(y_ticks)
    ax.set_ylim(-0.02 * total_calls, total_calls * 1.02)
    for y in y_ticks:
        ax.axhline(y, c="k", alpha=0.1, linewidth=1)

    max_seconds = (end_tstamp - start_tstamp) * 1.25
    xplot_step = max(int(max_seconds / 8), 1)
    x_ticks = np.arange(max_seconds // xplot_step + 2) * xplot_step
    ax.set_xlim(0, max_seconds)

    ax.set_xticks(x_ticks)
    for x in x_ticks:
        ax.axvline(x, c="k", alpha=0.2, linewidth=0.8)

    ax.grid(False)
    fig.tight_layout()

    if format is None:
        format = "svg"

    if dst is None:
        os.makedirs("plots", exist_ok=True)
        dst = os.path.join(
            os.getcwd(), "plots", "{}_{}".format(int(time.time()), f"timeline.{format}")
        )
    else:
        dst = os.path.expanduser(dst) if "~" in dst else dst
        dst = "{}/{}".format(os.path.realpath(dst), f"timeline.{format}")

    fig.savefig(dst)
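
And a usage sketch for the timeline callback (illustrative; the format defaults to "svg"):

    from cubed.diagnostics.timeline import TimelineVisualizationCallback

    # writes a scatter plot of task lifecycle timestamps to
    # history/<compute_id>/timeline.png
    (a + a).compute(callbacks=[TimelineVisualizationCallback(format="png")])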