Skip to content

Commit

Permalink
feat: use rich progress bars
Browse files Browse the repository at this point in the history
  • Loading branch information
andrzejnovak committed Mar 14, 2022
1 parent cd9a22f commit 359016f
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 66 deletions.
131 changes: 66 additions & 65 deletions coffea/processor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .accumulator import accumulate, set_accumulator, Accumulatable
from .dataframe import LazyDataFrame
from ..nanoevents import NanoEventsFactory, schemas
from ..util import _hash
from ..util import _hash, rich_bar

from collections.abc import Mapping, MutableMapping
from dataclasses import dataclass, field, asdict
Expand Down Expand Up @@ -610,19 +610,6 @@ def __call__(
if self.compression is not None:
function = _compression_wrapper(self.compression, function)

def merge_tqdm(chunks, accumulator=None, desc="Adding", **kwargs):
gen = (c for c in chunks)
return _compress(
accumulate(
tqdm(map(_decompress, gen),
disable=not self.status,
unit=self.unit,
total=len(chunks),
desc=desc,
**kwargs),
_decompress(accumulator),
), self.compression)

def processwith(pool, mergepool):
reducer = _reduce(self.compression)
FH = FuturesHolder(
Expand All @@ -635,42 +622,51 @@ def processwith(pool, mergepool):
nparts, minred, maxred = 5, 4, 100

try:
pbar = tqdm(disable=not self.status, unit=self.unit, total=len(items),
desc=self.desc, position=0, ascii=True)
if self.merging:
mbar = tqdm(disable=not self.status, total=1, desc="Merging",
position=1, ascii=True)
else:
merged = None

while len(FH.futures) + len(FH.merges) > 0:
FH.update()
reduce = min(maxred, max(len(FH.completed) // nparts + 1, minred))

pbar.update(FH.done["futures"] - pbar.n)
pbar.refresh()

if self.merging:
mbar.update(FH.done["merges"] - mbar.n)
mbar.refresh()
while len(FH.completed) > 1:
batch = [b for b in FH.fetch(reduce)]
if mergepool is None:
FH.merges.add(pool.submit(reducer, batch))
else:
FH.merges.add(mergepool.submit(reducer, batch))
mbar.total += 1
mbar.refresh()
else: # Merge within process
merged = merge_tqdm(FH.fetch(len(FH.completed)), merged, desc="Merging", leave=False)
progress = rich_bar()
prog_id = progress.add_task(self.desc, total=len(items), unit=self.unit)
_mdesc = "Merging" if self.merging else "Merging (main process, by batch)"
prog_id_merge = progress.add_task(_mdesc, total=0, unit='merges')
merged = None

with progress:
while len(FH.futures) + len(FH.merges) > 0:
FH.update()
reduce = min(maxred, max(len(FH.completed) // nparts + 1, minred))
progress.update(prog_id,
advance=FH.done["futures"] -
progress._tasks[prog_id].completed)

if self.merging:
progress.update(prog_id_merge,
advance=FH.done["merges"] -
progress._tasks[prog_id_merge].completed)
while len(FH.completed) > 1:
batch = [b for b in FH.fetch(reduce)]
if mergepool is None:
FH.merges.add(pool.submit(reducer, batch))
else:
FH.merges.add(mergepool.submit(reducer, batch))
progress.update(
prog_id_merge,
total=progress._tasks[prog_id_merge].total + 1)
else: # Merge within process
batch = FH.fetch(len(FH.completed))
merged = _compress(
accumulate(
progress.track(
map(_decompress, (c for c in batch)),
task_id=prog_id_merge,
total=progress._tasks[prog_id_merge].total +
len(batch)), _decompress(accumulator)),
self.compression)

progress.refresh()

# Add checkpointing

pbar.close()
if self.merging:
mbar.update(1) # last one
mbar.refresh()
mbar.close()
progress.update(prog_id_merge, advance=1)
progress.refresh()
merged = FH.completed.pop().result()
if len(FH.completed) > 0:
raise RuntimeError("Not all futures are added.")
Expand All @@ -682,25 +678,32 @@ def processwith(pool, mergepool):
for job in FH.futures:
job.cancel()

if self.merging:
with tqdm(disable=not self.status,
total=len(FH.merges),
desc="Recovering finished jobs",
position=1,
ascii=True) as mbar:
progress = rich_bar()
with progress:
if self.merging:
prog_id_wait = progress.add_task("Waiting for merge jobs", total=len(FH.merges), unit=self.unit)
while len(FH.merges) > 0:
FH.update()
mbar.update(FH.done["merges"] - mbar.n)
mbar.refresh()

def is_good(future):
return future.done(
) and not future.cancelled() and future.exception() is None

FH.update()
recovered = [future.result() for future in FH.completed if is_good(future)]

return merge_tqdm(recovered, accumulator, desc="Merging finished jobs"), e
progress.update(prog_id_wait,
completed=(progress._tasks[prog_id_wait].total - len(FH.merges)),
refresh=True)

def is_good(future):
return future.done(
) and not future.cancelled() and future.exception() is None

FH.update()
recovered = [future.result() for future in FH.completed if is_good(future)]
prog_id_merge = progress.add_task("Merging finished jobs", unit='merges')
merged = _compress(
accumulate(
progress.track(
map(_decompress, (c for c in recovered)),
task_id=prog_id_merge,
total=len(recovered))),
self.compression)

return accumulate([_decompress(merged), accumulator]), e
else:
raise type(e)(str(e)).with_traceback(sys.exc_info()[2]) from None

Expand Down Expand Up @@ -1520,10 +1523,8 @@ def __call__(

wrapped_out, e = executor(chunks, closure, None)
if e != 0:
# print(type(e[0])(str(e[0])).with_traceback(e[1]))
print(type(e)(str(e)))

print("X", wrapped_out)
if wrapped_out is None:
raise ValueError("No chunks were processed, veryify ``processor`` instance structure.")
else:
Expand Down
40 changes: 39 additions & 1 deletion coffea/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,44 @@ def _ensure_flat(array, allow_missing=False):
return array



from rich.progress import (Progress, BarColumn, TextColumn, TimeElapsedColumn,
TimeRemainingColumn, Column, ProgressColumn, Text)
from typing import Optional

class SpeedColumn(ProgressColumn):
"""Renders human readable transfer speed."""

def __init__(self, fmt: str = ".1f", table_column: Optional[Column] = None):
self.fmt = fmt
super().__init__(table_column=table_column)

def render(self, task: "Task") -> Text:
"""Show data transfer speed."""
speed = task.finished_speed or task.speed
if speed is None:
return Text("?", style="progress.data.speed")
return Text(f"{speed:{self.fmt}}", style="progress.data.speed")


def rich_bar():
return Progress(TextColumn("[bold blue]{task.description}", justify="right"),
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(bar_width=None),
TextColumn(
"[bold blue][progress.completed]{task.completed}/{task.total}",
justify="right"),
"[",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
"|",
SpeedColumn(".1f"),
TextColumn("[progress.data.speed]{task.fields[unit]}/s",
justify='right'),
"]",
auto_refresh=False)

# lifted from awkward - https://github.com/scikit-hep/awkward-1.0/blob/5fe31a916bf30df6c2ea10d4094f6f1aefcf3d0c/src/awkward/_util.py#L47-L61 # noqa
# drive our deprecations-as-errors as with awkward
def deprecate(exception, version, date=None):
Expand All @@ -94,4 +132,4 @@ def deprecate(exception, version, date=None):
{2}: {3}""".format(
version, date, type(exception).__name__, str(exception)
)
warnings.warn(message, FutureWarning)
warnings.warn(message, FutureWarning)

0 comments on commit 359016f

Please sign in to comment.