From 359016f883b88fda1ac381bfdc8535310a4f9d94 Mon Sep 17 00:00:00 2001
From: Andrzej
Date: Mon, 14 Mar 2022 16:27:45 +0100
Subject: [PATCH] feat: use rich progress bars

---
 coffea/processor/executor.py | 131 ++++++++++++++++++-----------------
 coffea/util.py               |  40 ++++++++++-
 2 files changed, 105 insertions(+), 66 deletions(-)

diff --git a/coffea/processor/executor.py b/coffea/processor/executor.py
index 75d75bc111..138c96e72b 100644
--- a/coffea/processor/executor.py
+++ b/coffea/processor/executor.py
@@ -21,7 +21,7 @@
 from .accumulator import accumulate, set_accumulator, Accumulatable
 from .dataframe import LazyDataFrame
 from ..nanoevents import NanoEventsFactory, schemas
-from ..util import _hash
+from ..util import _hash, rich_bar
 
 from collections.abc import Mapping, MutableMapping
 from dataclasses import dataclass, field, asdict
@@ -610,19 +610,6 @@ def __call__(
         if self.compression is not None:
             function = _compression_wrapper(self.compression, function)
 
-        def merge_tqdm(chunks, accumulator=None, desc="Adding", **kwargs):
-            gen = (c for c in chunks)
-            return _compress(
-                accumulate(
-                    tqdm(map(_decompress, gen),
-                         disable=not self.status,
-                         unit=self.unit,
-                         total=len(chunks),
-                         desc=desc,
-                         **kwargs),
-                    _decompress(accumulator),
-                ), self.compression)
-
         def processwith(pool, mergepool):
             reducer = _reduce(self.compression)
             FH = FuturesHolder(
@@ -635,42 +622,51 @@ def processwith(pool, mergepool):
             nparts, minred, maxred = 5, 4, 100
 
             try:
-                pbar = tqdm(disable=not self.status, unit=self.unit, total=len(items),
-                            desc=self.desc, position=0, ascii=True)
-                if self.merging:
-                    mbar = tqdm(disable=not self.status, total=1, desc="Merging",
-                                position=1, ascii=True)
-                else:
-                    merged = None
-
-                while len(FH.futures) + len(FH.merges) > 0:
-                    FH.update()
-                    reduce = min(maxred, max(len(FH.completed) // nparts + 1, minred))
-
-                    pbar.update(FH.done["futures"] - pbar.n)
-                    pbar.refresh()
-
-                    if self.merging:
-                        mbar.update(FH.done["merges"] - mbar.n)
-                        mbar.refresh()
-                        while len(FH.completed) > 1:
-                            batch = [b for b in FH.fetch(reduce)]
-                            if mergepool is None:
-                                FH.merges.add(pool.submit(reducer, batch))
-                            else:
-                                FH.merges.add(mergepool.submit(reducer, batch))
-                            mbar.total += 1
-                            mbar.refresh()
-                    else:  # Merge within process
-                        merged = merge_tqdm(FH.fetch(len(FH.completed)), merged, desc="Merging", leave=False)
+                progress = rich_bar()
+                prog_id = progress.add_task(self.desc, total=len(items), unit=self.unit)
+                _mdesc = "Merging" if self.merging else "Merging (main process, by batch)"
+                prog_id_merge = progress.add_task(_mdesc, total=0, unit='merges')
+                merged = None
+
+                with progress:
+                    while len(FH.futures) + len(FH.merges) > 0:
+                        FH.update()
+                        reduce = min(maxred, max(len(FH.completed) // nparts + 1, minred))
+                        progress.update(prog_id,
+                                        advance=FH.done["futures"] -
+                                        progress._tasks[prog_id].completed)
+
+                        if self.merging:
+                            progress.update(prog_id_merge,
+                                            advance=FH.done["merges"] -
+                                            progress._tasks[prog_id_merge].completed)
+                            while len(FH.completed) > 1:
+                                batch = [b for b in FH.fetch(reduce)]
+                                if mergepool is None:
+                                    FH.merges.add(pool.submit(reducer, batch))
+                                else:
+                                    FH.merges.add(mergepool.submit(reducer, batch))
+                                progress.update(
+                                    prog_id_merge,
+                                    total=progress._tasks[prog_id_merge].total + 1)
+                        else:  # Merge within process
+                            batch = FH.fetch(len(FH.completed))
+                            merged = _compress(
+                                accumulate(
+                                    progress.track(
+                                        map(_decompress, (c for c in batch)),
+                                        task_id=prog_id_merge,
+                                        total=progress._tasks[prog_id_merge].total +
+                                        len(batch)), _decompress(merged)),
+                                self.compression)
+
+                        progress.refresh()  # Add checkpointing
 
-                pbar.close()
                 if self.merging:
-                    mbar.update(1)  # last one
-                    mbar.refresh()
-                    mbar.close()
+                    progress.update(prog_id_merge, advance=1)
+                    progress.refresh()
                     merged = FH.completed.pop().result()
                     if len(FH.completed) > 0:
                         raise RuntimeError("Not all futures are added.")
@@ -682,25 +678,32 @@ def processwith(pool, mergepool):
                     for job in FH.futures:
                         job.cancel()
-                    if self.merging:
-                        with tqdm(disable=not self.status,
-                                  total=len(FH.merges),
-                                  desc="Recovering finished jobs",
-                                  position=1,
-                                  ascii=True) as mbar:
+                    progress = rich_bar()
+                    with progress:
+                        if self.merging:
+                            prog_id_wait = progress.add_task("Waiting for merge jobs", total=len(FH.merges), unit=self.unit)
                             while len(FH.merges) > 0:
                                 FH.update()
-                                mbar.update(FH.done["merges"] - mbar.n)
-                                mbar.refresh()
-
-                    def is_good(future):
-                        return future.done(
-                        ) and not future.cancelled() and future.exception() is None
-
-                    FH.update()
-                    recovered = [future.result() for future in FH.completed if is_good(future)]
-
-                    return merge_tqdm(recovered, accumulator, desc="Merging finished jobs"), e
+                                progress.update(prog_id_wait,
+                                                completed=(progress._tasks[prog_id_wait].total - len(FH.merges)),
+                                                refresh=True)
+
+                        def is_good(future):
+                            return future.done(
+                            ) and not future.cancelled() and future.exception() is None
+
+                        FH.update()
+                        recovered = [future.result() for future in FH.completed if is_good(future)]
+                        prog_id_merge = progress.add_task("Merging finished jobs", unit='merges')
+                        merged = _compress(
+                            accumulate(
+                                progress.track(
+                                    map(_decompress, (c for c in recovered)),
+                                    task_id=prog_id_merge,
+                                    total=len(recovered))),
+                            self.compression)
+
+                        return accumulate([_decompress(merged), accumulator]), e
                 else:
                     raise type(e)(str(e)).with_traceback(sys.exc_info()[2]) from None
@@ -1520,10 +1523,8 @@ def __call__(
 
         wrapped_out, e = executor(chunks, closure, None)
         if e != 0:
-            # print(type(e[0])(str(e[0])).with_traceback(e[1]))
             print(type(e)(str(e)))
-            print("X", wrapped_out)
             if wrapped_out is None:
                 raise ValueError("No chunks were processed, veryify ``processor`` instance structure.")
             else:
diff --git a/coffea/util.py b/coffea/util.py
index d991f87af6..9a6f5a7335 100644
--- a/coffea/util.py
+++ b/coffea/util.py
@@ -79,6 +79,44 @@ def _ensure_flat(array, allow_missing=False):
     return array
 
 
+
+from rich.progress import (Progress, BarColumn, TextColumn, TimeElapsedColumn,
+                           TimeRemainingColumn, Column, ProgressColumn, Text)
+from typing import Optional
+
+class SpeedColumn(ProgressColumn):
+    """Renders human readable transfer speed."""
+
+    def __init__(self, fmt: str = ".1f", table_column: Optional[Column] = None):
+        self.fmt = fmt
+        super().__init__(table_column=table_column)
+
+    def render(self, task: "Task") -> Text:
+        """Show data transfer speed."""
+        speed = task.finished_speed or task.speed
+        if speed is None:
+            return Text("?", style="progress.data.speed")
+        return Text(f"{speed:{self.fmt}}", style="progress.data.speed")
+
+
+def rich_bar():
+    return Progress(TextColumn("[bold blue]{task.description}", justify="right"),
+                    "[progress.percentage]{task.percentage:>3.0f}%",
+                    BarColumn(bar_width=None),
+                    TextColumn(
+                        "[bold blue][progress.completed]{task.completed}/{task.total}",
+                        justify="right"),
+                    "[",
+                    TimeElapsedColumn(),
+                    "<",
+                    TimeRemainingColumn(),
+                    "|",
+                    SpeedColumn(".1f"),
+                    TextColumn("[progress.data.speed]{task.fields[unit]}/s",
+                               justify='right'),
+                    "]",
+                    auto_refresh=False)
+
 # lifted from awkward - https://github.com/scikit-hep/awkward-1.0/blob/5fe31a916bf30df6c2ea10d4094f6f1aefcf3d0c/src/awkward/_util.py#L47-L61  # noqa
 # drive our deprecations-as-errors as with awkward
 def deprecate(exception, version, date=None):
@@ -94,4 +132,4 @@ def deprecate(exception, version, date=None):
     {2}: {3}""".format(
         version, date, type(exception).__name__, str(exception)
     )
-    warnings.warn(message, FutureWarning)
+    warnings.warn(message, FutureWarning)
\ No newline at end of file
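
Usage sketch for the new coffea.util.rich_bar() helper: it returns a rich.progress.Progress
configured with auto_refresh=False, so callers drive it by hand via add_task()/update()/refresh()
(or track() when wrapping an iterable), and every task must carry a `unit` field because the
speed column renders "{task.fields[unit]}/s". A minimal, hypothetical example outside the
executor; the chunk list and the sleep are stand-ins for real work, not coffea API:

    import time

    from coffea.util import rich_bar

    chunks = list(range(10))

    progress = rich_bar()  # rich.progress.Progress with auto_refresh=False
    work = progress.add_task("Processing", total=len(chunks), unit="chunk")
    merge = progress.add_task("Merging", total=len(chunks), unit="merges")

    with progress:
        for chunk in chunks:
            time.sleep(0.05)           # stand-in for real per-chunk work
            progress.update(work, advance=1)
            progress.refresh()         # auto_refresh is off, so redraw explicitly
        # track() advances an existing task while iterating, the same pattern
        # the executor uses when merging in the main process.
        for _ in progress.track(chunks, task_id=merge, total=len(chunks)):
            pass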