Skip to content

Commit

Permalink
Refactor archive progress bar (#4504)
Browse files Browse the repository at this point in the history
This commit introduces a new generic progress reporter interface (in `aiida/common/progress_reporter.py`),
that can be used for adding progress reporting to any process.
It is intended to deprecate the existing `aiida/tools/importexport/common/progress_bar.py` module.

The reporter is designed to work similar to logging,
such that its "handler" is set external to the actual function, e.g. by the CLI.
Its default implementation is to do nothing (a null reporter),
and there is convenience function to set a [tqdm](https://tqdm.github.io/) progress bar implementation (`set_progress_bar_tqdm`).

The reporter is intended to always be used as context manager,
e.g. to allow the progress bar to be removed once the process is complete.

The reporter has been implemented in the archive export module,
and it is intended that it will also be implemented in the archive import module.
At this point the existing `aiida/tools/importexport/common/progress_bar.py` module can be removed.
  • Loading branch information
chrisjsewell authored Oct 25, 2020
1 parent 1be12e1 commit 33c9f41
Show file tree
Hide file tree
Showing 11 changed files with 525 additions and 360 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ repos:
(?x)^(
aiida/engine/processes/calcjobs/calcjob.py|
aiida/tools/groups/paths.py|
aiida/tools/importexport/dbexport/__init__.py
aiida/tools/importexport/dbexport/__init__.py|
aiida/common/progress_reporter.py|
)$
- repo: local
Expand Down
21 changes: 17 additions & 4 deletions aiida/cmdline/commands/cmd_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
###########################################################################
# pylint: disable=too-many-arguments,import-error,too-many-locals
"""`verdi export` command."""

import os
import tempfile

Expand Down Expand Up @@ -67,6 +66,13 @@ def inspect(archive, version, data, meta_data):
@options.NODES()
@options.ARCHIVE_FORMAT()
@options.FORCE(help='overwrite output file if it already exists')
@click.option(
'-v',
'--verbosity',
default='INFO',
type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']),
help='Control the verbosity of console logging'
)
@options.graph_traversal_rules(GraphTraversalRules.EXPORT.value)
@click.option(
'--include-logs/--exclude-logs',
Expand All @@ -83,7 +89,7 @@ def inspect(archive, version, data, meta_data):
@decorators.with_dbenv()
def create(
output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs
create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity
):
"""
Export subsets of the provenance graph to file for sharing.
Expand All @@ -94,7 +100,9 @@ def create(
their provenance, according to the rules outlined in the documentation.
You can modify some of those rules using options of this command.
"""
from aiida.tools.importexport import export, ExportFileFormat
from aiida.common.log import override_log_formatter_context
from aiida.common.progress_reporter import set_progress_bar_tqdm
from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER
from aiida.tools.importexport.common.exceptions import ArchiveExportError

entities = []
Expand Down Expand Up @@ -132,8 +140,13 @@ def create(
elif archive_format == 'tar.gz':
export_format = ExportFileFormat.TAR_GZIPPED

if verbosity in ['DEBUG', 'INFO']:
set_progress_bar_tqdm(leave=(verbosity == 'DEBUG'))
EXPORT_LOGGER.setLevel(verbosity)

try:
export(entities, filename=output_file, file_format=export_format, **kwargs)
with override_log_formatter_context('%(message)s'):
export(entities, filename=output_file, file_format=export_format, **kwargs)
except ArchiveExportError as exception:
echo.echo_critical(f'failed to write the archive file. Exception: {exception}')
else:
Expand Down
6 changes: 5 additions & 1 deletion aiida/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@
from .extendeddicts import *
from .links import *
from .log import *
from .progress_reporter import *

__all__ = (datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__)
__all__ = (
datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ +
progress_reporter.__all__
)
32 changes: 20 additions & 12 deletions aiida/common/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,25 @@ def override_log_level(level=logging.CRITICAL):
logging.disable(level=logging.NOTSET)


@contextmanager
def override_log_formatter_context(fmt: str):
"""Temporarily use a different formatter for all handlers.
NOTE: One can _only_ set `fmt` (not `datefmt` or `style`).
Be aware! This may fail if the number of handlers is changed within the decorated function/method.
"""
temp_formatter = logging.Formatter(fmt=fmt)
cached_formatters = [handler.formatter for handler in AIIDA_LOGGER.handlers]

for handler in AIIDA_LOGGER.handlers:
handler.setFormatter(temp_formatter)

yield

for index, handler in enumerate(AIIDA_LOGGER.handlers):
handler.setFormatter(cached_formatters[index])


def override_log_formatter(fmt: str):
"""Temporarily use a different formatter for all handlers.
Expand All @@ -221,18 +240,7 @@ def override_log_formatter(fmt: str):

@decorator
def wrapper(wrapped, instance, args, kwargs): # pylint: disable=unused-argument
temp_formatter = logging.Formatter(fmt=fmt)

cached_formatters = []
for handler in AIIDA_LOGGER.handlers:
cached_formatters.append(handler.formatter)

try:
for handler in AIIDA_LOGGER.handlers:
handler.setFormatter(temp_formatter)
with override_log_formatter_context(fmt=fmt):
return wrapped(*args, **kwargs)
finally:
for index, handler in enumerate(AIIDA_LOGGER.handlers):
handler.setFormatter(cached_formatters[index])

return wrapper
150 changes: 150 additions & 0 deletions aiida/common/progress_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement,unused-argument
"""Provide a singleton progress reporter implementation.
The interface is inspired by `tqdm <https://github.com/tqdm/tqdm>`,
and indeed a valid implementation is::
from tqdm import tqdm
set_progress_reporter(tqdm, bar_format='{l_bar}{bar}{r_bar}')
"""
from functools import partial
from types import TracebackType
from typing import Any, Optional, Type

__all__ = (
'get_progress_reporter', 'set_progress_reporter', 'set_progress_bar_tqdm', 'ProgressReporterAbstract',
'TQDM_BAR_FORMAT'
)

TQDM_BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'


class ProgressReporterAbstract:
"""An abstract class for incrementing a progress reporter.
This class provides the base interface for any `ProgressReporter` class.
Example Usage::
with ProgressReporter(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
"""

def __init__(self, *, total: int, desc: Optional[str] = None, **kwargs: Any):
"""Initialise the progress reporting contextmanager.
:param total: The number of expected iterations.
:param desc: A description of the process
"""
self.total = total
self.desc = desc
self.increment = 0

def __enter__(self) -> 'ProgressReporterAbstract':
"""Enter the contextmanager."""
return self

def __exit__(
self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
):
"""Exit the contextmanager."""
return False

def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
"""Set the text shown by the progress reporter.
:param text: The text to show
:param refresh: Force refresh of the progress reporter
"""
self.desc = text

def update(self, n: int = 1): # pylint: disable=invalid-name
"""Update the progress counter.
:param n: Increment to add to the internal counter of iterations
"""
self.increment += n


class ProgressReporterNull(ProgressReporterAbstract):
"""A null implementation of the progress reporter.
This implementation does not output anything.
"""


PROGRESS_REPORTER: Type[ProgressReporterAbstract] = ProgressReporterNull # pylint: disable=invalid-name


def get_progress_reporter() -> Type[ProgressReporterAbstract]:
"""Return the progress reporter
Example Usage::
with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
"""
global PROGRESS_REPORTER
return PROGRESS_REPORTER # type: ignore


def set_progress_reporter(reporter: Optional[Type[ProgressReporterAbstract]] = None, **kwargs: Any):
"""Set the progress reporter implementation
:param reporter: A progress reporter for a process. If None, reset to ``ProgressReporterNull``.
:param kwargs: If present, set a partial function with these kwargs
The reporter should be a context manager that implements the
:func:`~aiida.common.progress_reporter.ProgressReporterAbstract` interface.
Example Usage::
set_progress_reporter(ProgressReporterNull)
with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()
"""
global PROGRESS_REPORTER
if reporter is None:
PROGRESS_REPORTER = ProgressReporterNull # type: ignore
elif kwargs:
PROGRESS_REPORTER = partial(reporter, **kwargs) # type: ignore
else:
PROGRESS_REPORTER = reporter # type: ignore


def set_progress_bar_tqdm(bar_format: Optional[str] = TQDM_BAR_FORMAT, leave: Optional[bool] = False, **kwargs: Any):
"""Set a `tqdm <https://github.com/tqdm/tqdm>`__ implementation of the progress reporter interface.
See :func:`~aiida.common.progress_reporter.set_progress_reporter` for details.
:param bar_format: Specify a custom bar string format.
:param leave: If True, keeps all traces of the progressbar upon termination of iteration.
If `None`, will leave only if `position` is `0`.
:param kwargs: pass to the tqdm init
"""
from tqdm import tqdm
set_progress_reporter(tqdm, bar_format=bar_format, leave=leave, **kwargs)
3 changes: 0 additions & 3 deletions aiida/tools/importexport/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
# The name of the subfolder in which the node files are stored
NODES_EXPORT_SUBFOLDER = 'nodes'

# Progress bar
BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'

# Giving names to the various entities. Attributes and links are not AiiDA
# entities but we will refer to them as entities in the file (to simplify
# references to them).
Expand Down
2 changes: 1 addition & 1 deletion aiida/tools/importexport/common/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from aiida.common.lang import type_check

from aiida.tools.importexport.common.config import BAR_FORMAT
from aiida.common.progress_reporter import TQDM_BAR_FORMAT as BAR_FORMAT
from aiida.tools.importexport.common.exceptions import ProgressBarError

__all__ = ('get_progress_bar', 'close_progress_bar')
Expand Down
Loading

0 comments on commit 33c9f41

Please sign in to comment.