Add a Rich progress bar #383

Merged 2 commits on Feb 16, 2024
112 changes: 112 additions & 0 deletions cubed/extensions/rich.py
@@ -0,0 +1,112 @@
import logging
import sys
from contextlib import contextmanager

from rich.console import RenderableType
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
Task,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
)
from rich.text import Text

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback


class RichProgressBar(Callback):
"""Rich progress bar for a computation."""

def on_compute_start(self, event):
# Set the pulse_style to the background colour to disable pulsing,
# since Rich will pulse all non-started bars.
logger_aware_progress = LoggerAwareProgress(
SpinnerWhenRunningColumn(),
TextColumn("[progress.description]{task.description}"),
LeftJustifiedMofNCompleteColumn(),
BarColumn(bar_width=None, pulse_style="bar.back"),
TaskProgressColumn(
text_format="[progress.percentage]{task.percentage:>3.1f}%"
),
TimeElapsedColumn(),
logger=logging.getLogger(),
)
progress = logger_aware_progress.__enter__()

progress_tasks = {}
for name, node in visit_nodes(event.dag, event.resume):
num_tasks = node["primitive_op"].num_tasks
progress_task = progress.add_task(f"{name}", start=False, total=num_tasks)
progress_tasks[name] = progress_task

self.logger_aware_progress = logger_aware_progress
self.progress = progress
self.progress_tasks = progress_tasks

def on_compute_end(self, event):
self.logger_aware_progress.__exit__(None, None, None)

def on_operation_start(self, event):
self.progress.start_task(self.progress_tasks[event.name])

def on_task_end(self, event):
self.progress.update(self.progress_tasks[event.name], advance=event.num_tasks)


class SpinnerWhenRunningColumn(SpinnerColumn):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Override so spinner is not shown when bar has not yet started
def render(self, task: "Task") -> RenderableType:
text = (
self.finished_text
if not task.started or task.finished
else self.spinner.render(task.get_time())
)
return text


class LeftJustifiedMofNCompleteColumn(MofNCompleteColumn):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def render(self, task: "Task") -> Text:
"""Show completed/total."""
completed = int(task.completed)
total = int(task.total) if task.total is not None else "?"
total_width = len(str(total))
return Text(
f"{completed}{self.separator}{total}".ljust(total_width + 1 + total_width),
style="progress.download",
)


# Based on CustomProgress from https://github.com/Textualize/rich/discussions/1578
@contextmanager
def LoggerAwareProgress(*args, **kwargs):
    """Wrapper around rich.progress.Progress to manage logging output to stderr."""
    logger = kwargs.pop("logger", None)
    stream_handlers = [
        x for x in logger.root.handlers if type(x) is logging.StreamHandler
    ]
    # Remember each handler's original stream so it can be restored on exit.
    prior_streams = {handler: handler.stream for handler in stream_handlers}
    try:
        with Progress(*args, **kwargs) as progress:
            # Send log output to stderr while the progress bar owns stdout.
            for handler in stream_handlers:
                handler.setStream(sys.stderr)

            yield progress

    finally:
        for handler, stream in prior_streams.items():
            handler.setStream(stream)
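For orientation, here is a minimal usage sketch of the new callback (not part of the diff; it mirrors the test added in cubed/tests/test_executor_features.py below, and the Spec values are illustrative placeholders):

```python
import cubed
import cubed.array_api as xp
from cubed.extensions.rich import RichProgressBar

# Illustrative spec; work_dir and allowed_mem are placeholder values.
spec = cubed.Spec(work_dir="tmp", allowed_mem="100MB")

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)

# One bar is drawn per operation in the plan; a bar's spinner only runs
# once its operation starts (see SpinnerWhenRunningColumn above).
c.compute(callbacks=[RichProgressBar()])
```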
7 changes: 6 additions & 1 deletion cubed/runtime/executors/coiled.py
@@ -5,7 +5,11 @@

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback, DagExecutor
from cubed.runtime.utils import execution_stats, handle_callbacks
from cubed.runtime.utils import (
execution_stats,
handle_callbacks,
handle_operation_start_callbacks,
)
from cubed.spec import Spec


@@ -27,6 +31,7 @@ def execute_dag(
) -> None:
# Note this currently only builds the task graph for each stage once it gets to that stage in computation
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
pipeline = node["pipeline"]
coiled_function = make_coiled_function(pipeline.function, coiled_kwargs)
input = list(
8 changes: 7 additions & 1 deletion cubed/runtime/executors/dask_distributed_async.py
@@ -20,7 +20,12 @@
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
from cubed.runtime.utils import execution_stats, gensym, handle_callbacks
from cubed.runtime.utils import (
execution_stats,
gensym,
handle_callbacks,
handle_operation_start_callbacks,
)
from cubed.spec import Spec


@@ -123,6 +128,7 @@ async def async_execute_dag(
if not compute_arrays_in_parallel:
# run one pipeline at a time
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
st = pipeline_to_stream(client, name, node["pipeline"], **kwargs)
async with st.stream() as streamer:
async for _, stats in streamer:
5 changes: 4 additions & 1 deletion cubed/runtime/executors/lithops.py
@@ -27,7 +27,7 @@
)
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, DagExecutor
from cubed.runtime.utils import handle_callbacks
from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks
from cubed.spec import Spec

logger = logging.getLogger(__name__)
@@ -180,6 +180,7 @@ def execute_dag(
with RetryingFunctionExecutor(function_executor) as executor:
if not compute_arrays_in_parallel:
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
pipeline = node["pipeline"]
for _, stats in map_unordered(
executor,
@@ -207,6 +208,8 @@
group_map_functions.append(f)
group_map_iterdata.append(pipeline.mappable)
group_names.append(name)
for name in group_names:
handle_operation_start_callbacks(callbacks, name)
for _, stats in map_unordered(
executor,
group_map_functions,
7 changes: 6 additions & 1 deletion cubed/runtime/executors/modal.py
@@ -10,7 +10,11 @@

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback, DagExecutor
from cubed.runtime.utils import execute_with_stats, handle_callbacks
from cubed.runtime.utils import (
execute_with_stats,
handle_callbacks,
handle_operation_start_callbacks,
)
from cubed.spec import Spec

RUNTIME_MEMORY_MIB = 2000
@@ -128,6 +132,7 @@ def execute_dag(
else:
raise ValueError(f"Unrecognized cloud: {cloud}")
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
pipeline = node["pipeline"]
task_create_tstamp = time.time()
for _, stats in app_function.map(
3 changes: 2 additions & 1 deletion cubed/runtime/executors/modal_async.py
@@ -18,7 +18,7 @@
)
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, DagExecutor
from cubed.runtime.utils import handle_callbacks
from cubed.runtime.utils import handle_callbacks, handle_operation_start_callbacks
from cubed.spec import Spec


@@ -127,6 +127,7 @@ async def async_execute_dag(
if not compute_arrays_in_parallel:
# run one pipeline at a time
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
st = pipeline_to_stream(app_function, name, node["pipeline"], **kwargs)
async with st.stream() as streamer:
async for _, stats in streamer:
2 changes: 2 additions & 0 deletions cubed/runtime/executors/python.py
@@ -4,6 +4,7 @@

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback, CubedPipeline, DagExecutor, TaskEndEvent
from cubed.runtime.utils import handle_operation_start_callbacks
from cubed.spec import Spec


@@ -24,6 +25,7 @@ def execute_dag(
**kwargs,
) -> None:
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
pipeline: CubedPipeline = node["pipeline"]
for m in pipeline.mappable:
exec_stage_func(
7 changes: 6 additions & 1 deletion cubed/runtime/executors/python_async.py
@@ -11,7 +11,11 @@
from cubed.runtime.executors.asyncio import async_map_unordered
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
from cubed.runtime.utils import execution_stats, handle_callbacks
from cubed.runtime.utils import (
execution_stats,
handle_callbacks,
handle_operation_start_callbacks,
)
from cubed.spec import Spec


@@ -92,6 +96,7 @@ async def async_execute_dag(
if not compute_arrays_in_parallel:
# run one pipeline at a time
for name, node in visit_nodes(dag, resume=resume):
handle_operation_start_callbacks(callbacks, name)
st = pipeline_to_stream(
concurrent_executor, name, node["pipeline"], **kwargs
)
11 changes: 11 additions & 0 deletions cubed/runtime/types.py
@@ -49,6 +49,14 @@ class ComputeEndEvent:
"""The computation DAG."""


@dataclass
class OperationStartEvent:
"""Callback information about an operation that is about to start."""

name: str
"""Name of the operation."""


@dataclass
class TaskEndEvent:
"""Callback information about a completed task (or tasks)."""
@@ -101,6 +109,9 @@ def on_compute_end(self, ComputeEndEvent):
"""
pass # pragma: no cover

    def on_operation_start(self, event):
        """Called when an operation is about to start."""
        pass

def on_task_end(self, event):
"""Called when the a task ends.

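To illustrate the extended callback API, here is a minimal custom callback (a sketch; the class name and print statements are illustrative, not part of this PR):

```python
from cubed.runtime.types import Callback

class OperationLoggerCallback(Callback):
    """Sketch of a callback using the new on_operation_start hook."""

    def on_operation_start(self, event):
        # event is an OperationStartEvent; event.name identifies the operation.
        print(f"starting operation {event.name}")

    def on_task_end(self, event):
        # event is a TaskEndEvent; event.num_tasks is how many tasks completed.
        print(f"{event.name}: {event.num_tasks} task(s) completed")
```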
8 changes: 7 additions & 1 deletion cubed/runtime/utils.py
@@ -2,7 +2,7 @@
from functools import partial
from itertools import islice

from cubed.runtime.types import TaskEndEvent
from cubed.runtime.types import OperationStartEvent, TaskEndEvent
from cubed.utils import peak_measured_mem

sym_counter = 0
@@ -39,6 +39,12 @@ def execution_stats(func):
return partial(execute_with_stats, func)


def handle_operation_start_callbacks(callbacks, name):
    """Construct an OperationStartEvent and send to all callbacks."""
    if callbacks is not None:
        event = OperationStartEvent(name)
        for callback in callbacks:
            callback.on_operation_start(event)


def handle_callbacks(callbacks, stats):
"""Construct a TaskEndEvent from stats and send to all callbacks."""

14 changes: 14 additions & 0 deletions cubed/tests/test_executor_features.py
@@ -9,6 +9,7 @@
import cubed.array_api as xp
import cubed.random
from cubed.extensions.history import HistoryCallback
from cubed.extensions.rich import RichProgressBar
from cubed.extensions.timeline import TimelineVisualizationCallback
from cubed.extensions.tqdm import TqdmProgressBar
from cubed.primitive.blockwise import apply_blockwise
@@ -97,6 +98,19 @@ def test_callbacks(spec, executor):
assert task_counter.value == num_created_arrays + 4


def test_rich_progress_bar(spec, executor):
# test indirectly by checking it doesn't cause a failure
progress = RichProgressBar()

a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
c = xp.add(a, b)
assert_array_equal(
c.compute(executor=executor, callbacks=[progress]),
np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]),
)


@pytest.mark.cloud
def test_callbacks_modal(spec, modal_executor):
task_counter = TaskCounter(check_timestamps=False)
1 change: 1 addition & 0 deletions pyproject.toml
@@ -40,6 +40,7 @@ diagnostics = [
"pydot",
"pandas",
"matplotlib",
"rich",
"seaborn",
]
beam = ["apache-beam", "gcsfs"]
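Since rich is added to the diagnostics extra above, installing cubed with that extra (for example, pip install 'cubed[diagnostics]') should pull it in, assuming a release that includes this change.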