From a54c319a913fbe80808255873b29e25f28fcde35 Mon Sep 17 00:00:00 2001
From: Chris Sewell
Date: Sat, 24 Oct 2020 02:19:36 +0200
Subject: [PATCH] re-re-factor progress bar!

---
 .pre-commit-config.yaml                       |   3 +-
 aiida/cmdline/commands/cmd_export.py          |   9 +-
 aiida/common/__init__.py                      |   6 +-
 aiida/common/progress_reporter.py             | 103 +++++++++++++
 aiida/tools/importexport/dbexport/__init__.py | 137 ++++++------------
 5 files changed, 164 insertions(+), 94 deletions(-)
 create mode 100644 aiida/common/progress_reporter.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 04ef5e1843..acc268f778 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -38,7 +38,8 @@ repos:
         (?x)^(
             aiida/engine/processes/calcjobs/calcjob.py|
             aiida/tools/groups/paths.py|
-            aiida/tools/importexport/dbexport/__init__.py
+            aiida/tools/importexport/dbexport/__init__.py|
+            aiida/common/progress_reporter.py
         )$

 - repo: local

diff --git a/aiida/cmdline/commands/cmd_export.py b/aiida/cmdline/commands/cmd_export.py
index f7f1d266a3..42fb0567d1 100644
--- a/aiida/cmdline/commands/cmd_export.py
+++ b/aiida/cmdline/commands/cmd_export.py
@@ -9,7 +9,7 @@
 ###########################################################################
 # pylint: disable=too-many-arguments,import-error,too-many-locals
 """`verdi export` command."""
-
+from functools import partial
 import os
 import tempfile
@@ -95,8 +95,11 @@ def create(
     their provenance, according to the rules outlined in the documentation.
     You can modify some of those rules using options of this command.
     """
+    from tqdm import tqdm
+    from aiida.common.progress_reporter import set_progress_reporter
     from aiida.tools.importexport import export, ExportFileFormat
     from aiida.tools.importexport.common.exceptions import ArchiveExportError
+    from aiida.tools.importexport.common.config import BAR_FORMAT

     entities = []
@@ -133,8 +136,10 @@ def create(
     elif archive_format == 'tar.gz':
         export_format = ExportFileFormat.TAR_GZIPPED

+    set_progress_reporter(partial(tqdm, bar_format=BAR_FORMAT, leave=verbose))
+
     try:
-        export(entities, filename=output_file, file_format=export_format, verbose=verbose, **kwargs)
+        export(entities, filename=output_file, file_format=export_format, **kwargs)
     except ArchiveExportError as exception:
         echo.echo_critical(f'failed to write the archive file. Exception: {exception}')
     else:

diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py
index 938113d34b..ea59db2024 100644
--- a/aiida/common/__init__.py
+++ b/aiida/common/__init__.py
@@ -20,5 +20,9 @@
 from .extendeddicts import *
 from .links import *
 from .log import *
+from .progress_reporter import *

-__all__ = (datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__)
+__all__ = (
+    datastructures.__all__ + exceptions.__all__ + extendeddicts.__all__ + links.__all__ + log.__all__ +
+    progress_reporter.__all__
+)
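Note: the `verdi export create` hunk above is the only call site in this patch that opts in to a real progress bar; the library itself now defaults to a silent no-op reporter. Outside of the click command, the same wiring reduces to the following sketch, using only names introduced by this patch::

    from functools import partial

    from tqdm import tqdm

    from aiida.common.progress_reporter import set_progress_reporter
    from aiida.tools.importexport.common.config import BAR_FORMAT

    # every get_progress_reporter() call site below will now draw a tqdm bar;
    # leave=False discards the bar once the enclosing context exits
    set_progress_reporter(partial(tqdm, bar_format=BAR_FORMAT, leave=False))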
diff --git a/aiida/common/progress_reporter.py b/aiida/common/progress_reporter.py
new file mode 100644
index 0000000000..657c3b5957
--- /dev/null
+++ b/aiida/common/progress_reporter.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+###########################################################################
+# Copyright (c), The AiiDA team. All rights reserved.                     #
+# This file is part of the AiiDA code.                                    #
+#                                                                         #
+# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core #
+# For further information on the license, see the LICENSE.txt file       #
+# For further information please visit http://www.aiida.net               #
+###########################################################################
+# pylint: disable=global-statement,unused-argument
+"""Provide a singleton progress reporter implementation.
+
+The interface is inspired by `tqdm <https://github.com/tqdm/tqdm>`__,
+and indeed a valid implementation is::
+
+    from tqdm import tqdm
+    set_progress_reporter(tqdm)
+
+"""
+from contextlib import contextmanager
+from typing import Any, Callable, ContextManager, Iterator, Optional
+
+__all__ = ('get_progress_reporter', 'set_progress_reporter', 'progress_reporter_base', 'ProgressIncrementerBase')
+
+
+class ProgressIncrementerBase:
+    """A base class for incrementing a progress reporter."""
+
+    def set_description_str(self, text: Optional[str] = None, refresh: bool = True):
+        """Set the text shown by the progress reporter.
+
+        :param text: The text to show
+        :param refresh: Force refresh of the progress reporter
+
+        """
+
+    def update(self, n: int = 1):  # pylint: disable=invalid-name
+        """Update the progress counter.
+
+        :param n: Increment to add to the internal counter of iterations
+
+        """
+
+
+@contextmanager
+def progress_reporter_base(*,
+                           total: int,
+                           desc: Optional[str] = None,
+                           **kwargs: Any) -> Iterator[ProgressIncrementerBase]:
+    """A context manager for providing a progress reporter for a process.
+
+    Example Usage::
+
+        with progress_reporter_base(total=10, desc="A process:") as progress:
+            for i in range(10):
+                progress.set_description_str(f"A process: {i}")
+                progress.update()
+
+    :param total: The number of expected iterations.
+    :param desc: A description of the process
+    :yield: A class for incrementing the progress reporter
+
+    """
+    yield ProgressIncrementerBase()
+
+
+PROGRESS_REPORTER = progress_reporter_base
+
+
+def get_progress_reporter() -> Callable[..., ContextManager[Any]]:
+    """Return the progress reporter.
+
+    Example Usage::
+
+        with get_progress_reporter()(total=10, desc="A process:") as progress:
+            for i in range(10):
+                progress.set_description_str(f"A process: {i}")
+                progress.update()
+
+    """
+    global PROGRESS_REPORTER
+    return PROGRESS_REPORTER  # type: ignore
+
+
+def set_progress_reporter(reporter: Optional[Callable[..., ContextManager[Any]]] = None):
+    """Set the progress reporter implementation.
+
+    :param reporter: A context manager for providing a progress reporter for a process.
+        If None, reset to the default null reporter
+
+    The reporter should be a context manager that implements the
+    :func:`~aiida.common.progress_reporter.progress_reporter_base` interface.
+
+    Example Usage::
+
+        with get_progress_reporter()(total=10, desc="A process:") as progress:
+            for i in range(10):
+                progress.set_description_str(f"A process: {i}")
+                progress.update()
+
+    """
+    global PROGRESS_REPORTER
+    PROGRESS_REPORTER = reporter or progress_reporter_base  # type: ignore
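Note: tqdm is not required — any context manager with the `progress_reporter_base` signature that yields an object honouring the `ProgressIncrementerBase` interface can be registered. A minimal print-based sketch (`PrintIncrementer` and `print_reporter` are hypothetical, not part of this patch)::

    from contextlib import contextmanager
    from typing import Any, Iterator, Optional

    from aiida.common.progress_reporter import ProgressIncrementerBase, set_progress_reporter

    class PrintIncrementer(ProgressIncrementerBase):
        """Hypothetical incrementer that prints each step instead of drawing a bar."""

        def __init__(self, total: int, desc: Optional[str] = None):
            self._total = total
            self._desc = desc or ''
            self._count = 0

        def update(self, n: int = 1):
            self._count += n
            print(f'{self._desc} {self._count}/{self._total}')

    @contextmanager
    def print_reporter(*, total: int, desc: Optional[str] = None, **kwargs: Any) -> Iterator[PrintIncrementer]:
        # same keyword-only contract as progress_reporter_base
        yield PrintIncrementer(total, desc)

    set_progress_reporter(print_reporter)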
diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/importexport/dbexport/__init__.py
index d862b621a5..044507ccf7 100644
--- a/aiida/tools/importexport/dbexport/__init__.py
+++ b/aiida/tools/importexport/dbexport/__init__.py
@@ -12,7 +12,6 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass, field, asdict
-from functools import partial
 import logging
 import os
 import tarfile
@@ -34,8 +33,6 @@
 )
 import warnings

-from tqdm import tqdm
-
 from aiida import get_version, orm
 from aiida.common import json
 from aiida.common.exceptions import LicensingException
@@ -43,13 +40,13 @@
 from aiida.common.links import GraphTraversalRules
 from aiida.common.lang import type_check
 from aiida.common.log import LOG_LEVEL_REPORT, override_log_formatter
+from aiida.common.progress_reporter import get_progress_reporter
 from aiida.common.warnings import AiidaDeprecationWarning
 from aiida.orm.utils._repository import Repository
 from aiida.tools.importexport.common import (
     exceptions,
 )
 from aiida.tools.importexport.common.config import (
-    BAR_FORMAT,
     COMMENT_ENTITY_NAME,
     COMPUTER_ENTITY_NAME,
     EXPORT_VERSION,
@@ -134,35 +131,24 @@ def total_time(self) -> float:
         return (self.time_write_stop or self.time_collect_stop) - self.time_collect_start


-# should make more generic
-ProgressContext = Callable[..., ContextManager[tqdm]]
-
-
 class ArchiveWriterAbstract(ABC):
     """An abstract interface for AiiDA archive writers."""

-    def __init__(self, filename: str, progress_context: Optional[ProgressContext] = None, **kwargs: Any):
+    def __init__(self, filename: str, **kwargs: Any):
         """An archive writer

         :param filename: the filename (possibly including the absolute path) of the file on which to export.
-        :param progress_context: A context manager for creating and updating a progress bar

         """
         # pylint: disable=unused-argument
         self._filename = filename
-        self._progress_context = progress_context or partial(tqdm, bar_format=BAR_FORMAT)

     @property
     def filename(self) -> str:
         """Return the filename to write to."""
         return self._filename

-    @property
-    def progress_context(self) -> ProgressContext:
-        """Return the progress bar context manager."""
-        return self._progress_context
-
     @property
     @abstractmethod
     def file_format_verbose(self) -> str:
@@ -186,9 +172,7 @@ def write(
         """


-def _write_to_json_archive(
-    folder: Union[Folder, ZipFolder], export_data: ArchiveData, export_version: str, progress_context: ProgressContext
-) -> None:
+def _write_to_json_archive(folder: Union[Folder, ZipFolder], export_data: ArchiveData, export_version: str) -> None:
     """Write data to the archive."""
     # subfolder inside the export package
     nodesubfolder = folder.get_subfolder(NODES_EXPORT_SUBFOLDER, create=True, reset_limit=True)
@@ -235,13 +219,13 @@ def _write_to_json_archive(
         return

     pbar_base_str = 'Exporting repository - '
-    with progress_context(total=len(export_data.node_uuids)) as progress_bar:
+    with get_progress_reporter()(total=len(export_data.node_uuids)) as progress:
         for uuid in export_data.node_uuids:
             sharded_uuid = export_shard_uuid(uuid)

-            progress_bar.set_description_str(f"{pbar_base_str}UUID={uuid.split('-')[0]}", refresh=False)
-            progress_bar.update()
+            progress.set_description_str(f"{pbar_base_str}UUID={uuid.split('-')[0]}", refresh=False)
+            progress.update()

             # Important to set create=False, otherwise the subfolder is created twice.
             # Maybe this is a bug of insert_path?
@@ -264,17 +248,16 @@ class WriterJsonZip(ArchiveWriterAbstract):
     which writes database data as a single JSON and repository data in a zipped folder system.
     """

-    def __init__(self, filename: str, progress_context: Optional[ProgressContext] = None, **kwargs: Any):
+    def __init__(self, filename: str, use_compression: bool = True, **kwargs: Any):
         """A writer for zipped archives.

         :param filename: the filename (possibly including the absolute path) of the file on which to export.
-        :param progress_context: A context manager for creating and updating a progress bar
         :param use_compression: Whether or not to compress the zip file.

         """
-        super().__init__(filename, progress_context, **kwargs)
-        self._use_compression = kwargs.get('use_compression', True)
+        super().__init__(filename, **kwargs)
+        self._use_compression = use_compression

     @property
     def file_format_verbose(self) -> str:
@@ -296,7 +279,6 @@ def write(self, export_data: ArchiveData) -> dict:
             folder=folder,
             export_data=export_data,
             export_version=self.export_version,
-            progress_context=self.progress_context
         )

         return {}
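Note: with `progress_context` gone from the constructor, third-party writers shrink accordingly. Judging from the hunks above, a subclass needs `file_format_verbose`, `export_version` (not shown in this diff, but referenced as `self.export_version`) and `write`; a hypothetical no-op writer, e.g. for timing the collection phase alone, might look like::

    from aiida.tools.importexport.common.config import EXPORT_VERSION
    from aiida.tools.importexport.dbexport import ArchiveData, ArchiveWriterAbstract

    class WriterNull(ArchiveWriterAbstract):
        """Hypothetical writer that discards the collected archive data."""

        @property
        def file_format_verbose(self) -> str:
            return 'Null (discarded)'

        @property
        def export_version(self) -> str:
            return EXPORT_VERSION

        def write(self, export_data: ArchiveData) -> dict:
            # a real writer would serialise export_data here
            return {}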
@@ -309,16 +291,15 @@ class WriterJsonTar(ArchiveWriterAbstract):
     The entire containing folder is then compressed as a tar file.
     """

-    def __init__(self, filename: str, progress_context: Optional[ProgressContext] = None, **kwargs: Any):
+    def __init__(self, filename: str, **kwargs: Any):
         """A writer for tarred archives.

         :param filename: the filename (possibly including the absolute path) of the file on which to export.
-        :param progress_context: A context manager for creating and updating a progress bar
         :param sandbox_in_repo: Create the temporary uncompressed folder within the aiida repository

         """
-        super().__init__(filename, progress_context, **kwargs)
+        super().__init__(filename, **kwargs)
         self.sandbox_in_repo = kwargs.get('sandbox_in_repo', True)

     @property
     def file_format_verbose(self) -> str:
@@ -341,7 +322,6 @@ def write(self, export_data: ArchiveData) -> dict:
             folder=folder,
             export_data=export_data,
             export_version=self.export_version,
-            progress_context=self.progress_context
         )

         with tarfile.open(self.filename, 'w:gz', format=tarfile.PAX_FORMAT, dereference=True) as tar:
@@ -359,21 +339,17 @@ class WriterJsonFolder(ArchiveWriterAbstract):
     This writer is mainly intended for backward compatibility with `export_tree`.
     """

-    def __init__(self, filename: str, progress_context: Optional[ProgressContext] = None, **kwargs: Any):
+    def __init__(self, filename: str, folder: Optional[Union[Folder, ZipFolder]] = None, **kwargs: Any):
         """A writer to a plain folder.

         :param filename: the filename (possibly including the absolute path) of the file on which to export.
-        :param progress_context: A context manager for creating and updating a progress bar
         :param folder: a folder to write the archive to.

         """
-        super().__init__(filename, progress_context, **kwargs)
-        type_check(
-            kwargs.get('folder', None), (Folder, ZipFolder),
-            msg='`folder` must be specified and given as an AiiDA Folder entity'
-        )
-        self._folder = kwargs.get('folder')
+        super().__init__(filename, **kwargs)
+        type_check(folder, (Folder, ZipFolder), msg='`folder` must be specified and given as an AiiDA Folder entity')
+        self._folder = cast(Union[Folder, ZipFolder], folder)

     @property
     def file_format_verbose(self) -> str:
@@ -394,7 +370,6 @@ def write(self, export_data: ArchiveData) -> dict:
             folder=self._folder,
             export_data=export_data,
             export_version=self.export_version,
-            progress_context=self.progress_context
         )

         return {}
@@ -414,7 +389,7 @@ def get_writer(file_format: str) -> Type[ArchiveWriterAbstract]:
             f'Can only export in the formats: {tuple(writers.keys())}, please specify one for "file_format".'
         )

-    return writers[file_format]
+    return cast(Type[ArchiveWriterAbstract], writers[file_format])


 def export(
@@ -429,7 +404,6 @@ def export(
     allowed_licenses: Optional[Union[list, Callable]] = None,
     forbidden_licenses: Optional[Union[list, Callable]] = None,
     writer_init: Optional[Dict[str, Any]] = None,
-    verbose: bool = False,
     **traversal_rules: bool,
 ) -> ExportReport:
     """Export AiiDA data to an archive file.
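Note: after this change a programmatic export is silent by default — no `verbose` flag, no bar — unless a reporter was registered beforehand. A sketch of the call (the group label is made up)::

    from aiida import orm
    from aiida.tools.importexport import export, ExportFileFormat

    group = orm.load_group('my-group')  # hypothetical group label
    export([group], filename='my-archive.aiida', file_format=ExportFileFormat.ZIP)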
@@ -530,10 +504,7 @@ def export(
         name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items()
     }

-    progress_context = partial(tqdm, bar_format=BAR_FORMAT, leave=verbose)
-    writer = get_writer(file_format)(
-        filename=filename, progress_context=progress_context, use_compression=use_compression, **(writer_init or {})
-    )
+    writer = get_writer(file_format)(filename=filename, use_compression=use_compression, **(writer_init or {}))

     if silent:
         logging.disable(logging.CRITICAL)
@@ -553,7 +524,6 @@ def export(
         report_data['time_collect_start'] = time.time()
         export_data = _collect_archive_data(
             entities=entities,
-            progress_context=progress_context,
             allowed_licenses=allowed_licenses,
             forbidden_licenses=forbidden_licenses,
             include_comments=include_comments,
@@ -591,7 +561,6 @@
 @override_log_formatter('%(message)s')
 def _collect_archive_data(
-    progress_context: ProgressContext,
     entities: Optional[Iterable[Any]] = None,
     allowed_licenses: Optional[Union[list, Callable]] = None,
     forbidden_licenses: Optional[Union[list, Callable]] = None,
@@ -658,14 +627,14 @@
     all_fields_info, unique_identifiers = get_all_fields_info()

-    entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities, progress_context)
+    entities_starting_set, given_node_entry_ids = _get_starting_node_ids(entities)

     (
         node_ids_to_be_exported,
         node_pk_2_uuid_mapping,
         links_uuid,
         traversal_rules,
-    ) = _collect_node_ids(given_node_entry_ids, progress_context, **traversal_rules)
+    ) = _collect_node_ids(given_node_entry_ids, **traversal_rules)

     _check_node_licenses(node_ids_to_be_exported, allowed_licenses, forbidden_licenses)
@@ -673,12 +642,11 @@
     entries_queries = _collect_entity_queries(
         node_ids_to_be_exported,
         entities_starting_set,
         node_pk_2_uuid_mapping,
-        progress_context,
         include_comments,
         include_logs,
     )

-    export_data = _perform_export_queries(entries_queries, progress_context)
+    export_data = _perform_export_queries(entries_queries)

     # note this was originally below the attributes and group_uuid gather
     check_process_nodes_sealed({
@@ -698,7 +666,7 @@
         level=LOG_LEVEL_REPORT,
     )

-    groups_uuid = _get_groups_uuid(export_data, progress_context)
+    groups_uuid = _get_groups_uuid(export_data)

     # Turn sets into lists to be able to export them as JSON metadata.
     for entity, entity_set in entities_starting_set.items():
@@ -717,7 +685,7 @@

     # we get the node data last, because it is generally the largest data source
     # and later we may look to stream this data in chunks
-    node_data = _collect_node_data(node_ids_to_be_exported, progress_context)
+    node_data = _collect_node_data(node_ids_to_be_exported)

     return ArchiveData(
         metadata,
     )
@@ -729,12 +697,10 @@
-def _get_starting_node_ids(entities: List[Any],
-                           progress_context: ProgressContext) -> Tuple[DefaultDict[str, Set[str]], Set[int]]:
+def _get_starting_node_ids(entities: List[Any]) -> Tuple[DefaultDict[str, Set[str]], Set[int]]:
     """Get the starting node UUIDs and PKs

     :param entities: a list of entity instances
-    :param silent: suppress console prints and progress bar.

     :raises exceptions.ArchiveExportError:
     :return: entities_starting_set, given_node_entry_ids

     """
@@ -747,14 +713,9 @@
     if not total:
         return entities_starting_set, given_node_entry_ids

-    with progress_context(desc='Collecting chosen entities', total=total) as progress_bar:
+    with get_progress_reporter()(desc='Collecting chosen entities', total=total) as progress:

         for entry in entities:

-            # This returns the class name (as in imports). E.g. for a model node:
-            # aiida.backends.djsite.db.models.DbNode
-            # entry_class_string = get_class_string(entry)
-            # Now a load the backend-independent name into entry_entity_name, e.g. Node!
-            # entry_entity_name = schema_to_entity_names(entry_class_string)
             if issubclass(entry.__class__, orm.Group):
                 entities_starting_set[GROUP_ENTITY_NAME].add(entry.uuid)
             elif issubclass(entry.__class__, orm.Node):
                     ' which is not a Node, Computer, or Group instance'
                 )

-            progress_bar.update()
+            progress.update()

         # Add all the nodes contained within the specified groups
         if GROUP_ENTITY_NAME in entities_starting_set:

-            progress_bar.set_description_str('Retrieving Nodes from Groups ...', refresh=False)
+            progress.set_description_str('Retrieving Nodes from Groups ...', refresh=False)

             # Use single query instead of given_group.nodes iterator for performance.
             qh_groups = (
@@ -803,16 +764,16 @@
     return entities_starting_set, given_node_entry_ids


-def _collect_node_ids(given_node_entry_ids: Set[int], progress_context: ProgressContext,
+def _collect_node_ids(given_node_entry_ids: Set[int],
                       **traversal_rules: bool) -> Tuple[Set[int], Dict[int, str], List[dict], Dict[str, bool]]:
     """Iteratively explore the AiiDA graph to find further nodes that should also be exported

     At the same time, we will create the links_uuid list of dicts to be exported

     """
-    with progress_context(desc='Traversing provenance via links ...', total=1) as progress_bar:
+    with get_progress_reporter()(desc='Traversing provenance via links ...', total=1) as progress:
         traverse_output = get_nodes_export(starting_pks=given_node_entry_ids, get_links=True, **traversal_rules)
-        progress_bar.update()
+        progress.update()

     node_ids_to_be_exported = traverse_output['nodes']
     graph_traversal_rules = traverse_output['rules']
@@ -872,7 +833,6 @@ def _collect_entity_queries(
     node_ids_to_be_exported: Set[int],
     entities_starting_set: DefaultDict[str, Set[str]],
     node_pk_2_uuid_mapping: Dict[int, str],
-    progress_context: ProgressContext,
     include_comments: bool = True,
     include_logs: bool = True,
 ) -> Dict[str, orm.QueryBuilder]:

     all_fields_info, _ = get_all_fields_info()

     total = 1 + ((1 if include_logs else 0) + (1 if include_comments else 0) if node_ids_to_be_exported else 0)
-    with progress_context(desc='Initializing export of all entities', total=total) as progress_bar:
+    with get_progress_reporter()(desc='Initializing export of all entities', total=total) as progress:

         # Universal "entities" attributed to all types of nodes
         # Logs

             res = set(builder.all(flat=True))
             given_log_entry_ids.update(res)

-            progress_bar.update()
+            progress.update()

         # Comments
         if include_comments and node_ids_to_be_exported:

             res = set(builder.all(flat=True))
             given_comment_entry_ids.update(res)

-            progress_bar.update()
+            progress.update()

         # Here we get all the columns that we plan to project per entity that we would like to extract
         given_entities = set(entities_starting_set.keys())
@@ -927,19 +887,19 @@
         if given_comment_entry_ids:
             given_entities.add(COMMENT_ENTITY_NAME)

-        progress_bar.update()
+        progress.update()

     entities_to_add: Dict[str, orm.QueryBuilder] = {}
     if not given_entities:
         return entities_to_add

-    with progress_context(total=len(given_entities)) as progress_bar:
+    with get_progress_reporter()(total=len(given_entities)) as progress:

         pbar_base_str = 'Preparing entities'

         for given_entity in given_entities:
-            progress_bar.set_description_str(f'{pbar_base_str} - {given_entity}s', refresh=False)
-            progress_bar.update()
+            progress.set_description_str(f'{pbar_base_str} - {given_entity}s', refresh=False)
+            progress.update()

             project_cols = ['id']
             # The following gets a list of fields that we need,
@@ -980,12 +940,10 @@
     return entities_to_add


-def _perform_export_queries(entries_queries: Dict[str, orm.QueryBuilder],
-                            progress_context: ProgressContext) -> Dict[str, Dict[int, dict]]:
+def _perform_export_queries(entries_queries: Dict[str, orm.QueryBuilder]) -> Dict[str, Dict[int, dict]]:
     """Start automatic recursive export data generation

     :param entries_queries: partial queries for all entities to export
-    :param silent: suppress console prints and progress bar.

     :return: export data mappings by entity type -> pk -> db_columns, e.g.
         {'ENTITY_NAME': {<pk>: {'uuid': 'abc', ...}, ...}, ...}
@@ -1002,11 +960,11 @@
     entity_separator = '_'

     counts = [p_query.count() for p_query in entries_queries.values()]

-    with progress_context(total=sum(counts)) as progress_bar:
+    with get_progress_reporter()(total=sum(counts)) as progress:

         for entity_name, partial_query in entries_queries.items():

-            progress_bar.set_description_str(f'Exporting {entity_name} fields', refresh=False)
+            progress.set_description_str(f'Exporting {entity_name} fields', refresh=False)

             foreign_fields = {k: v for k, v in all_fields_info[entity_name].items() if 'requires' in v}
@@ -1022,7 +980,7 @@
             for temp_d in partial_query.iterdict():

-                progress_bar.update()
+                progress.update()

                 for key in temp_d:
                     # Get current entity
@@ -1045,12 +1003,11 @@
     return export_data


-def _collect_node_data(all_node_pks: Set[int], progress_context: ProgressContext) -> Iterable[Tuple[str, dict, dict]]:
+def _collect_node_data(all_node_pks: Set[int]) -> Iterable[Tuple[str, dict, dict]]:
     """Gather attributes and extras for nodes

     :param export_data: mappings by entity type -> pk -> db_columns
     :param all_node_pks: set of pks
-    :param silent: for progress printing

     :return: iterable of (uuid, attributes, extras)
@@ -1069,18 +1026,18 @@
         project=['id', 'attributes', 'extras'],
     )

-    with progress_context(total=all_nodes_query.count()) as progress_bar:
-        progress_bar.set_description_str('Exporting Attributes and Extras', refresh=False)
+    with get_progress_reporter()(total=all_nodes_query.count()) as progress:
+        progress.set_description_str('Exporting Attributes and Extras', refresh=False)

         for node_pk, attributes, extras in all_nodes_query.iterall():
-            progress_bar.update()
+            progress.update()

             node_data.append((str(node_pk), attributes, extras))

     return node_data


-def _get_groups_uuid(export_data: Dict[str, Dict[int, dict]], progress_context: ProgressContext) -> Dict[str, Set[str]]:
+def _get_groups_uuid(export_data: Dict[str, Dict[int, dict]]) -> Dict[str, Set[str]]:
     """Get node UUIDs per group."""
     EXPORT_LOGGER.debug('GATHERING GROUP ELEMENTS...')
     groups_uuid: Dict[str, Set[str]] = defaultdict(set)
@@ -1106,10 +1063,10 @@ def _get_groups_uuid(export_data: Dict[str, Dict[int, dict]], progress_context:
     if not total_node_uuids_for_groups:
         return groups_uuid

-    with progress_context(desc='Exporting Groups ...', total=total_node_uuids_for_groups) as progress_bar:
+    with get_progress_reporter()(desc='Exporting Groups ...', total=total_node_uuids_for_groups) as progress:

         for group_uuid, node_uuid in group_uuids_with_node_uuids.iterall():
-            progress_bar.update()
+            progress.update()

             groups_uuid[group_uuid].add(node_uuid)
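Note: conversely, anything that wants silence back (tests, daemon workers) can reset the singleton, since the default `progress_reporter_base` simply swallows the `update`/`set_description_str` calls used throughout this file::

    from aiida.common.progress_reporter import get_progress_reporter, set_progress_reporter

    set_progress_reporter(None)  # restore the no-op progress_reporter_base

    with get_progress_reporter()(total=2, desc='Exporting Groups ...') as progress:
        progress.update()  # accepted, but displays nothing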