Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor archive progress bar #4504

Merged
merged 16 commits into from
Oct 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ repos:
(?x)^(
aiida/engine/processes/calcjobs/calcjob.py|
aiida/tools/groups/paths.py|
aiida/tools/importexport/dbexport/__init__.py
aiida/tools/importexport/dbexport/__init__.py|
aiida/common/progress_reporter.py|
)$

- repo: local
Expand Down
21 changes: 17 additions & 4 deletions aiida/cmdline/commands/cmd_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
###########################################################################
# pylint: disable=too-many-arguments,import-error,too-many-locals
"""`verdi export` command."""

import os
import tempfile

Expand Down Expand Up @@ -67,6 +66,13 @@ def inspect(archive, version, data, meta_data):
@options.NODES()
@options.ARCHIVE_FORMAT()
@options.FORCE(help='overwrite output file if it already exists')
@click.option(
'-v',
'--verbosity',
default='INFO',
type=click.Choice(['DEBUG', 'INFO', 'WARNING', 'CRITICAL']),
help='Control the verbosity of console logging'
)
@options.graph_traversal_rules(GraphTraversalRules.EXPORT.value)
@click.option(
'--include-logs/--exclude-logs',
Expand All @@ -83,7 +89,7 @@ def inspect(archive, version, data, meta_data):
@decorators.with_dbenv()
def create(
output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward,
create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs
create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, verbosity
):
"""
Export subsets of the provenance graph to file for sharing.
Expand All @@ -94,7 +100,9 @@ def create(
their provenance, according to the rules outlined in the documentation.
You can modify some of those rules using options of this command.
"""
from aiida.tools.importexport import export, ExportFileFormat
from aiida.common.log import override_log_formatter_context
from aiida.common.progress_reporter import set_progress_bar_tqdm
from aiida.tools.importexport import export, ExportFileFormat, EXPORT_LOGGER
from aiida.tools.importexport.common.exceptions import ArchiveExportError

entities = []
Expand Down Expand Up @@ -132,8 +140,13 @@ def create(
elif archive_format == 'tar.gz':
export_format = ExportFileFormat.TAR_GZIPPED

if verbosity in ['DEBUG', 'INFO']:
set_progress_bar_tqdm(leave=(verbosity == 'DEBUG'))
EXPORT_LOGGER.setLevel(verbosity)

try:
export(entities, filename=output_file, file_format=export_format, **kwargs)
with override_log_formatter_context('%(message)s'):
export(entities, filename=output_file, file_format=export_format, **kwargs)
except ArchiveExportError as exception:
echo.echo_critical(f'failed to write the archive file. Exception: {exception}')
else:
Expand Down
6 changes: 5 additions & 1 deletion aiida/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,9 @@
from .extendeddicts import *
from .links import *
from .log import *
from .progress_reporter import *

__all__ = (datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__)
__all__ = (
datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ +
progress_reporter.__all__
)
32 changes: 20 additions & 12 deletions aiida/common/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,25 @@ def override_log_level(level=logging.CRITICAL):
logging.disable(level=logging.NOTSET)


@contextmanager
def override_log_formatter_context(fmt: str):
"""Temporarily use a different formatter for all handlers.

NOTE: One can _only_ set `fmt` (not `datefmt` or `style`).
Be aware! This may fail if the number of handlers is changed within the decorated function/method.
"""
temp_formatter = logging.Formatter(fmt=fmt)
cached_formatters = [handler.formatter for handler in AIIDA_LOGGER.handlers]

for handler in AIIDA_LOGGER.handlers:
handler.setFormatter(temp_formatter)

yield

for index, handler in enumerate(AIIDA_LOGGER.handlers):
handler.setFormatter(cached_formatters[index])


def override_log_formatter(fmt: str):
"""Temporarily use a different formatter for all handlers.

Expand All @@ -221,18 +240,7 @@ def override_log_formatter(fmt: str):

@decorator
def wrapper(wrapped, instance, args, kwargs): # pylint: disable=unused-argument
temp_formatter = logging.Formatter(fmt=fmt)

cached_formatters = []
for handler in AIIDA_LOGGER.handlers:
cached_formatters.append(handler.formatter)

try:
for handler in AIIDA_LOGGER.handlers:
handler.setFormatter(temp_formatter)
with override_log_formatter_context(fmt=fmt):
return wrapped(*args, **kwargs)
finally:
for index, handler in enumerate(AIIDA_LOGGER.handlers):
handler.setFormatter(cached_formatters[index])

return wrapper
150 changes: 150 additions & 0 deletions aiida/common/progress_reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=global-statement,unused-argument
"""Provide a singleton progress reporter implementation.
ltalirz marked this conversation as resolved.
Show resolved Hide resolved

The interface is inspired by `tqdm <https://github.com/tqdm/tqdm>`,
and indeed a valid implementation is::

from tqdm import tqdm
set_progress_reporter(tqdm, bar_format='{l_bar}{bar}{r_bar}')

"""
from functools import partial
from types import TracebackType
from typing import Any, Optional, Type

__all__ = (
'get_progress_reporter', 'set_progress_reporter', 'set_progress_bar_tqdm', 'ProgressReporterAbstract',
'TQDM_BAR_FORMAT'
)

TQDM_BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'


class ProgressReporterAbstract:
"""An abstract class for incrementing a progress reporter.

This class provides the base interface for any `ProgressReporter` class.

Example Usage::

with ProgressReporter(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()

"""

def __init__(self, *, total: int, desc: Optional[str] = None, **kwargs: Any):
"""Initialise the progress reporting contextmanager.

:param total: The number of expected iterations.
:param desc: A description of the process

"""
self.total = total
self.desc = desc
self.increment = 0

def __enter__(self) -> 'ProgressReporterAbstract':
"""Enter the contextmanager."""
return self

def __exit__(
self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
):
"""Exit the contextmanager."""
return False

def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
"""Set the text shown by the progress reporter.

:param text: The text to show
:param refresh: Force refresh of the progress reporter

"""
self.desc = text

def update(self, n: int = 1): # pylint: disable=invalid-name
"""Update the progress counter.

:param n: Increment to add to the internal counter of iterations

"""
self.increment += n


class ProgressReporterNull(ProgressReporterAbstract):
"""A null implementation of the progress reporter.

This implementation does not output anything.
"""


PROGRESS_REPORTER: Type[ProgressReporterAbstract] = ProgressReporterNull # pylint: disable=invalid-name


def get_progress_reporter() -> Type[ProgressReporterAbstract]:
"""Return the progress reporter

Example Usage::

with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()

"""
global PROGRESS_REPORTER
return PROGRESS_REPORTER # type: ignore


def set_progress_reporter(reporter: Optional[Type[ProgressReporterAbstract]] = None, **kwargs: Any):
"""Set the progress reporter implementation

:param reporter: A progress reporter for a process. If None, reset to ``ProgressReporterNull``.

:param kwargs: If present, set a partial function with these kwargs

The reporter should be a context manager that implements the
:func:`~aiida.common.progress_reporter.ProgressReporterAbstract` interface.

Example Usage::

set_progress_reporter(ProgressReporterNull)
with get_progress_reporter()(total=10, desc="A process:") as progress:
for i in range(10):
progress.set_description_str(f"A process: {i}")
progress.update()

"""
global PROGRESS_REPORTER
if reporter is None:
PROGRESS_REPORTER = ProgressReporterNull # type: ignore
elif kwargs:
PROGRESS_REPORTER = partial(reporter, **kwargs) # type: ignore
else:
PROGRESS_REPORTER = reporter # type: ignore


def set_progress_bar_tqdm(bar_format: Optional[str] = TQDM_BAR_FORMAT, leave: Optional[bool] = False, **kwargs: Any):
"""Set a `tqdm <https://github.com/tqdm/tqdm>`__ implementation of the progress reporter interface.

See :func:`~aiida.common.progress_reporter.set_progress_reporter` for details.

:param bar_format: Specify a custom bar string format.
:param leave: If True, keeps all traces of the progressbar upon termination of iteration.
If `None`, will leave only if `position` is `0`.
:param kwargs: pass to the tqdm init

"""
from tqdm import tqdm
set_progress_reporter(tqdm, bar_format=bar_format, leave=leave, **kwargs)
3 changes: 0 additions & 3 deletions aiida/tools/importexport/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
# The name of the subfolder in which the node files are stored
NODES_EXPORT_SUBFOLDER = 'nodes'

# Progress bar
BAR_FORMAT = '{desc:40.40}{percentage:6.1f}%|{bar}| {n_fmt}/{total_fmt}'

# Giving names to the various entities. Attributes and links are not AiiDA
# entities but we will refer to them as entities in the file (to simplify
# references to them).
Expand Down
2 changes: 1 addition & 1 deletion aiida/tools/importexport/common/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from aiida.common.lang import type_check

from aiida.tools.importexport.common.config import BAR_FORMAT
from aiida.common.progress_reporter import TQDM_BAR_FORMAT as BAR_FORMAT
ltalirz marked this conversation as resolved.
Show resolved Hide resolved
from aiida.tools.importexport.common.exceptions import ProgressBarError

__all__ = ('get_progress_bar', 'close_progress_bar')
Expand Down
Loading