From 803d16a48941f557a7d558bb6740dc07b26c0715 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 2 Nov 2022 19:55:39 +0100 Subject: [PATCH 1/6] =?UTF-8?q?=F0=9F=90=9B=20FIX:=20Import=20archive=20in?= =?UTF-8?q?to=20large=20DB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit As detailed in https://www.sqlite.org/limits.html, SQLITE_MAX_VARIABLE_NUMBER puts a limit on how many variables can be used in a single SQL query. This size can be easily reached, if filtering by nodes in a large database. Therefore, this commit changes the filtering of UUIDs to be on the client side, then batches queries for the full nodes by a fixed number. --- aiida/tools/archive/imports.py | 194 ++++++++++++++++++++------------- 1 file changed, 119 insertions(+), 75 deletions(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index b6e860aade..ed05589e6c 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -9,6 +9,7 @@ ########################################################################### # pylint: disable=too-many-branches,too-many-lines,too-many-locals,too-many-statements """Import an archive.""" +from dataclasses import dataclass from pathlib import Path from typing import Callable, Dict, Literal, Optional, Set, Tuple, Union @@ -54,10 +55,20 @@ DUPLICATE_LABEL_TEMPLATE = '{0} (Imported #{1})' +@dataclass +class QueryParams: + """Parameters for executing backend queries.""" + batch_size: int + """Batch size for streaming database rows.""" + filter_size: int + """Maximum size of parameters allowed in a single query filter.""" + + def import_archive( path: Union[str, Path], *, archive_format: Optional[ArchiveFormatAbstract] = None, + filter_size: int = 999, batch_size: int = 1000, import_new_extras: bool = True, merge_extras: MergeExtrasType = ('k', 'n', 'l'), @@ -72,6 +83,7 @@ def import_archive( :param path: the path to the archive :param archive_format: The class for interacting with the archive + :param filter_size: Maximum size of parameters allowed in a single query filter :param batch_size: Batch size for streaming database rows :param import_new_extras: Keep extras on new nodes (except private aiida keys), else strip :param merge_extras: Rules for merging extras into existing nodes. 
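The commit message above refers to SQLite's SQLITE_MAX_VARIABLE_NUMBER limit, which defaults to 999 in SQLite builds before 3.32.0 and 32766 from 3.32.0 onwards; a single `in`/`!in` filter over every UUID in a large database can exceed it. A minimal standalone sketch of that failure mode, using only the Python standard library (illustration only, not code from this patch):

# Illustration: binding more variables than SQLite allows in one statement.
import sqlite3

conn = sqlite3.connect(':memory:')
params = list(range(100_000))  # far more placeholders than any default limit permits
placeholders = ','.join('?' * len(params))
try:
    conn.execute(f'SELECT 1 WHERE 1 IN ({placeholders})', params)
except sqlite3.Error as exc:
    print(exc)  # typically an OperationalError: "too many SQL variables"

The patch avoids this by never putting an unbounded list of UUIDs into one filter, as the hunks that follow show.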
@@ -117,6 +129,7 @@ def import_archive( type_check(test_run, bool) backend = backend or get_manager().get_profile_storage() type_check(backend, StorageBackend) + qparams = QueryParams(batch_size=batch_size, filter_size=filter_size) if group and not group.is_stored: group.store() @@ -157,28 +170,28 @@ def import_archive( # Every addition/update is made in a single transaction, which is commited on exit with backend.transaction(): - user_ids_archive_backend = _import_users(backend_from, backend, batch_size) - computer_ids_archive_backend = _import_computers(backend_from, backend, batch_size) + user_ids_archive_backend = _import_users(backend_from, backend, qparams) + computer_ids_archive_backend = _import_computers(backend_from, backend, qparams) if include_authinfos: _import_authinfos( - backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend + backend_from, backend, qparams, user_ids_archive_backend, computer_ids_archive_backend ) node_ids_archive_backend = _import_nodes( - backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend, + backend_from, backend, qparams, user_ids_archive_backend, computer_ids_archive_backend, import_new_extras, merge_extras ) - _import_logs(backend_from, backend, batch_size, node_ids_archive_backend) + _import_logs(backend_from, backend, qparams, node_ids_archive_backend) _import_comments( - backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend, merge_comments + backend_from, backend, qparams, user_ids_archive_backend, node_ids_archive_backend, merge_comments ) - _import_links(backend_from, backend, batch_size, node_ids_archive_backend) + _import_links(backend_from, backend, qparams, node_ids_archive_backend) group_labels = _import_groups( - backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend + backend_from, backend, qparams, user_ids_archive_backend, node_ids_archive_backend ) import_group_id = None if create_group: - import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, batch_size) - new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, batch_size) + import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, qparams) + new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, qparams) if test_run: # exit before we write anything to the database or repository @@ -195,35 +208,45 @@ def import_archive( def _add_new_entities( etype: EntityTypes, total: int, unique_field: str, backend_unique_id: dict, backend_from: StorageBackend, - backend_to: StorageBackend, batch_size: int, transform: Callable[[dict], dict] + backend_to: StorageBackend, qparams: QueryParams, transform: Callable[[dict], dict] ) -> None: """Add new entities to the output backend and update the mapping of unique field -> id.""" IMPORT_LOGGER.report(f'Adding {total} new {etype.value}(s)') - iterator = QueryBuilder(backend=backend_from).append( - entity_type_to_orm[etype], - filters={ - unique_field: { - '!in': list(backend_unique_id) - } - } if backend_unique_id else {}, - project=['**'], - tag='entity' - ).iterdict(batch_size=batch_size) + + # collect the unique entities from the input backend to be added to the output backend + ufields = [] + query = QueryBuilder(backend=backend_from).append(entity_type_to_orm[etype], project=[unique_field]) + for (ufield,) in query.distinct().iterall(batch_size=qparams.batch_size): + if ufield not in 
backend_unique_id: + ufields.append(ufield) + with get_progress_reporter()(desc=f'Adding new {etype.value}(s)', total=total) as progress: - for nrows, rows in batch_iter(iterator, batch_size, transform): + for nrows, ufields_batch in batch_iter(ufields, qparams.filter_size): + rows = [ + transform(row) for row in QueryBuilder(backend=backend_from).append( + entity_type_to_orm[etype], + filters={ + unique_field: { + 'in': ufields_batch + } + }, + project=['**'], + tag='entity' + ).dict(batch_size=qparams.batch_size) + ] new_ids = backend_to.bulk_insert(etype, rows) backend_unique_id.update({row[unique_field]: pk for pk, row in zip(new_ids, rows)}) progress.update(nrows) -def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int) -> Dict[int, int]: +def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams) -> Dict[int, int]: """Import users from one backend to another. :returns: mapping of input backend id to output backend id """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_email = dict(qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=batch_size)) + input_id_email = dict(qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=qparams.batch_size)) # get matching emails from the backend output_email_id = {} @@ -235,7 +258,7 @@ def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, batc 'email': { 'in': list(input_id_email.values()) } - }, project=['email', 'id']).all(batch_size=batch_size) + }, project=['email', 'id']).all(batch_size=qparams.batch_size) ) new_users = len(input_id_email) - len(output_email_id) @@ -247,21 +270,21 @@ def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, batc # add new users and update output_email_id with their email -> id mapping transform = lambda row: {k: v for k, v in row['entity'].items() if k != 'id'} _add_new_entities( - EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, batch_size, transform + EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, qparams, transform ) # generate mapping of input backend id to output backend id return {int(i): output_email_id[email] for i, email in input_id_email.items()} -def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int) -> Dict[int, int]: +def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams) -> Dict[int, int]: """Import computers from one backend to another. 
:returns: mapping of input backend id to output backend id """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -273,7 +296,7 @@ def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=batch_size) + }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) ) new_computers = len(input_id_uuid) - len(backend_uuid_id) @@ -287,7 +310,7 @@ def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, # Labels should be unique, so we create new labels on clashes labels = { label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Computer, project='label' - ).iterall(batch_size=batch_size) + ).iterall(batch_size=qparams.batch_size) } relabelled = 0 @@ -311,8 +334,7 @@ def transform(row: dict) -> dict: return data _add_new_entities( - EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, - transform + EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform ) if relabelled: @@ -323,8 +345,8 @@ def transform(row: dict) -> dict: def _import_authinfos( - backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int], - computer_ids_archive_backend: Dict[int, int] + backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, + user_ids_archive_backend: Dict[int, int], computer_ids_archive_backend: Dict[int, int] ) -> None: """Import logs from one backend to another. @@ -336,7 +358,7 @@ def _import_authinfos( qbuilder.append( orm.AuthInfo, project=['id', 'aiidauser_id', 'dbcomputer_id'], - ).all(batch_size=batch_size) + ).all(batch_size=qparams.batch_size) ) # translate user_id / computer_id, from -> to @@ -363,7 +385,7 @@ def _import_authinfos( project=['id', 'aiidauser_id', 'dbcomputer_id'] ) backend_id_user_comp = [(user_id, comp_id) - for _, user_id, comp_id in qbuilder.all(batch_size=batch_size) + for _, user_id, comp_id in qbuilder.all(batch_size=qparams.batch_size) if (user_id, comp_id) in to_user_id_comp_id] new_authinfos = len(input_id_user_comp) - len(backend_id_user_comp) @@ -396,23 +418,28 @@ def transform(row: dict) -> dict: with get_progress_reporter()( desc=f'Adding new {EntityTypes.AUTHINFO.value}(s)', total=qbuilder.count() ) as progress: - for nrows, rows in batch_iter(iterator, batch_size, transform): + for nrows, rows in batch_iter(iterator, qparams.batch_size, transform): backend_to.bulk_insert(EntityTypes.AUTHINFO, rows) progress.update(nrows) def _import_nodes( - backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int], - computer_ids_archive_backend: Dict[int, int], import_new_extras: bool, merge_extras: MergeExtrasType + backend_from: StorageBackend, + backend_to: StorageBackend, + qparams: QueryParams, + user_ids_archive_backend: Dict[int, int], + computer_ids_archive_backend: Dict[int, int], + import_new_extras: bool, + merge_extras: MergeExtrasType, ) -> Dict[int, int]: - """Import users from one backend to another. + """Import nodes from one backend to another. 
:returns: mapping of input backend id to output backend id """ IMPORT_LOGGER.report('Collecting Node(s) ...') # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -424,19 +451,19 @@ def _import_nodes( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=batch_size) + }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) ) new_nodes = len(input_id_uuid) - len(backend_uuid_id) if backend_uuid_id: - _merge_node_extras(backend_from, backend_to, batch_size, backend_uuid_id, merge_extras) + _merge_node_extras(backend_from, backend_to, qparams, backend_uuid_id, merge_extras) if new_nodes: # add new nodes and update backend_uuid_id with their uuid -> id mapping transform = NodeTransform(user_ids_archive_backend, computer_ids_archive_backend, import_new_extras) _add_new_entities( - EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform ) # generate mapping of input backend id to output backend id @@ -481,7 +508,10 @@ def __call__(self, row: dict) -> dict: def _import_logs( - backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, node_ids_archive_backend: Dict[int, int] + backend_from: StorageBackend, + backend_to: StorageBackend, + qparams: QueryParams, + node_ids_archive_backend: Dict[int, int], ) -> Dict[int, int]: """Import logs from one backend to another. @@ -489,7 +519,7 @@ def _import_logs( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -501,7 +531,7 @@ def _import_logs( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=batch_size) + }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) ) new_logs = len(input_id_uuid) - len(backend_uuid_id) @@ -521,7 +551,7 @@ def transform(row: dict) -> dict: return data _add_new_entities( - EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform ) # generate mapping of input backend id to output backend id @@ -529,7 +559,7 @@ def transform(row: dict) -> dict: def _merge_node_extras( - backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, backend_uuid_id: Dict[str, int], + backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, backend_uuid_id: Dict[str, int], mode: MergeExtrasType ) -> None: """Merge extras from the input backend with the ones in the output backend. 
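The UUID-keyed imports above (users, computers, nodes, logs) all funnel through `_add_new_entities`, whose new body first collects the unique field values client side and then issues `in` filters no larger than `filter_size`. A self-contained sketch of that pattern with toy data — the helper names here are simplified stand-ins, not the patch's actual API:

from typing import Iterable, Iterator, List, Tuple

def batch_iter(items: Iterable, size: int) -> Iterator[Tuple[int, list]]:
    # simplified stand-in for the batch_iter helper used in the diff:
    # yields (number_of_rows, chunk) with at most `size` items per chunk
    chunk: list = []
    for item in items:
        chunk.append(item)
        if len(chunk) == size:
            yield len(chunk), chunk
            chunk = []
    if chunk:
        yield len(chunk), chunk

def select_new(archive_uuids: Iterable[str], existing_uuids: set) -> List[str]:
    # client-side filtering replaces the previous single '!in' SQL filter,
    # so no query has to bind every existing UUID as a variable
    return [uuid for uuid in archive_uuids if uuid not in existing_uuids]

existing_in_db = {'b', 'd'}                      # toy stand-in for the target backend
to_add = select_new(['a', 'b', 'c', 'd', 'e'], existing_in_db)
for nrows, uuid_batch in batch_iter(to_add, 2):  # 2 plays the role of filter_size
    # the real code queries with filters={'uuid': {'in': uuid_batch}} and bulk-inserts
    print(nrows, uuid_batch)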
@@ -559,7 +589,9 @@ def _merge_node_extras( IMPORT_LOGGER.report(f'Replacing {num_existing} existing Node extras') transform = lambda row: {'id': backend_uuid_id[row[0]], 'extras': row[1]} with get_progress_reporter()(desc='Replacing extras', total=input_extras.count()) as progress: - for nrows, rows in batch_iter(input_extras.iterall(batch_size=batch_size), batch_size, transform): + for nrows, rows in batch_iter( + input_extras.iterall(batch_size=qparams.batch_size), qparams.batch_size, transform + ): backend_to.bulk_update(EntityTypes.NODE, rows) progress.update(nrows) return @@ -640,8 +672,10 @@ def _transform(data: Tuple[Tuple[str, dict], Tuple[str, dict]]) -> dict: with get_progress_reporter()(desc='Merging extras', total=input_extras.count()) as progress: for nrows, rows in batch_iter( - zip(input_extras.iterall(batch_size=batch_size), backend_extras.iterall(batch_size=batch_size)), batch_size, - _transform + zip( + input_extras.iterall(batch_size=qparams.batch_size), + backend_extras.iterall(batch_size=qparams.batch_size) + ), qparams.batch_size, _transform ): backend_to.bulk_update(EntityTypes.NODE, rows) progress.update(nrows) @@ -676,7 +710,7 @@ def __call__(self, row: dict) -> dict: def _import_comments( backend_from: StorageBackend, backend: StorageBackend, - batch_size: int, + qparams: QueryParams, user_ids_archive_backend: Dict[int, int], node_ids_archive_backend: Dict[int, int], merge_comments: MergeCommentsType, @@ -687,7 +721,7 @@ def _import_comments( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -699,7 +733,7 @@ def _import_comments( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=batch_size) + }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) ) new_comments = len(input_id_uuid) - len(backend_uuid_id) @@ -722,7 +756,9 @@ def _transform(row): return data with get_progress_reporter()(desc='Overwriting comments', total=archive_comments.count()) as progress: - for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + for nrows, rows in batch_iter( + archive_comments.iterall(batch_size=qparams.batch_size), qparams.batch_size, _transform + ): backend.bulk_update(EntityTypes.COMMENT, rows) progress.update(nrows) @@ -738,7 +774,9 @@ def _transform(row): cmt.set_content(new_comment) with get_progress_reporter()(desc='Updating comments', total=archive_comments.count()) as progress: - for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + for nrows, rows in batch_iter( + archive_comments.iterall(batch_size=qparams.batch_size), qparams.batch_size, _transform + ): progress.update(nrows) else: @@ -746,7 +784,7 @@ def _transform(row): if new_comments: # add new comments and update backend_uuid_id with their uuid -> id mapping _add_new_entities( - EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, batch_size, + EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, qparams, CommentTransform(user_ids_archive_backend, node_ids_archive_backend) ) @@ -755,7 +793,10 @@ def _transform(row): def _import_links( - backend_from: StorageBackend, backend_to: 
StorageBackend, batch_size: int, node_ids_archive_backend: Dict[int, int] + backend_from: StorageBackend, + backend_to: StorageBackend, + qparams: QueryParams, + node_ids_archive_backend: Dict[int, int], ) -> None: """Import links from one backend to another.""" @@ -811,7 +852,7 @@ def _import_links( tuple(link) for link in orm.QueryBuilder(backend=backend_to). append(entity_type='link', filters={ 'type': link_type.value - }, project=['input_id', 'output_id', 'label']).iterall(batch_size=batch_size) + }, project=['input_id', 'output_id', 'label']).iterall(batch_size=qparams.batch_size) } # create additional validators # note, we only populate them when required, to reduce memory usage @@ -823,7 +864,9 @@ def _import_links( new_count = existing_count = 0 insert_rows = [] with get_progress_reporter()(desc=f'Processing {link_type.value!r} Link(s)', total=total) as progress: - for in_id, in_type, out_id, out_type, link_id, link_label in archive_query.iterall(batch_size=batch_size): + for in_id, in_type, out_id, out_type, link_id, link_label in archive_query.iterall( + batch_size=qparams.batch_size + ): progress.update() @@ -876,7 +919,7 @@ def _import_links( existing_out_id_label.add((out_id, link_label)) # flush new rows, once batch size is reached - if (new_count % batch_size) == 0: + if (new_count % qparams.batch_size) == 0: backend_to.bulk_insert(EntityTypes.LINK, insert_rows) insert_rows = [] @@ -924,8 +967,8 @@ def __call__(self, row: dict) -> dict: def _import_groups( - backend_from: StorageBackend, backend_to: StorageBackend, batch_size: int, user_ids_archive_backend: Dict[int, int], - node_ids_archive_backend: Dict[int, int] + backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, + user_ids_archive_backend: Dict[int, int], node_ids_archive_backend: Dict[int, int] ) -> Set[str]: """Import groups from the input backend, and add group -> node records. 
@@ -933,7 +976,7 @@ def _import_groups( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -945,13 +988,13 @@ def _import_groups( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=batch_size) + }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) ) # get all labels labels = { label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Group, project='label' - ).iterall(batch_size=batch_size) + ).iterall(batch_size=qparams.batch_size) } new_groups = len(input_id_uuid) - len(backend_uuid_id) @@ -966,7 +1009,7 @@ def _import_groups( transform = GroupTransform(user_ids_archive_backend, labels) _add_new_entities( - EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform ) if transform.relabelled: @@ -995,7 +1038,7 @@ def group_node_transform(row): with get_progress_reporter()(desc=f'Adding new {EntityTypes.GROUP_NODE.value}(s)', total=total) as progress: for nrows, rows in batch_iter( - iterator.iterall(batch_size=batch_size), batch_size, group_node_transform + iterator.iterall(batch_size=qparams.batch_size), qparams.batch_size, group_node_transform ): backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) progress.update(nrows) @@ -1005,7 +1048,7 @@ def group_node_transform(row): def _make_import_group( group: Optional[orm.Group], labels: Set[str], node_ids_archive_backend: Dict[int, int], backend_to: StorageBackend, - batch_size: int + qparams: QueryParams ) -> Optional[int]: """Make an import group containing all imported nodes. 
@@ -1049,7 +1092,7 @@ def _make_import_group( group_node_ids = { pk for pk, in orm.QueryBuilder(backend=backend_to).append(orm.Group, filters={ 'id': group_id - }, tag='group').append(orm.Node, with_group='group', project='id').iterall(batch_size=batch_size) + }, tag='group').append(orm.Node, with_group='group', project='id').iterall(batch_size=qparams.batch_size) } # Add all the nodes to the Group @@ -1060,20 +1103,21 @@ def _make_import_group( 'dbgroup_id': group_id, 'dbnode_id': node_id } for node_id in node_ids_archive_backend.values() if node_id not in group_node_ids) - for nrows, rows in batch_iter(iterator, batch_size): + for nrows, rows in batch_iter(iterator, qparams.batch_size): backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) progress.update(nrows) return group_id -def _get_new_object_keys(key_format: str, backend_from: StorageBackend, backend_to: StorageBackend, - batch_size: int) -> Set[str]: +def _get_new_object_keys( + key_format: str, backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams +) -> Set[str]: """Return the object keys that need to be added to the backend.""" archive_hashkeys: Set[str] = set() query = QueryBuilder(backend=backend_from).append(orm.Node, project='repository_metadata') with get_progress_reporter()(desc='Collecting archive Node file keys', total=query.count()) as progress: - for repository_metadata, in query.iterall(batch_size=batch_size): + for repository_metadata, in query.iterall(batch_size=qparams.batch_size): archive_hashkeys.update(key for key in Repository.flatten(repository_metadata).values() if key is not None) progress.update() From 1440a60d2003b5df085a9d2b8f3db3f95a05823e Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 4 Nov 2022 15:53:17 +0100 Subject: [PATCH 2/6] Update aiida/tools/archive/imports.py Co-authored-by: Sebastiaan Huber --- aiida/tools/archive/imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index ed05589e6c..3f57143f49 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -215,7 +215,7 @@ def _add_new_entities( # collect the unique entities from the input backend to be added to the output backend ufields = [] - query = QueryBuilder(backend=backend_from).append(entity_type_to_orm[etype], project=[unique_field]) + query = QueryBuilder(backend=backend_from).append(entity_type_to_orm[etype], project=unique_field) for (ufield,) in query.distinct().iterall(batch_size=qparams.batch_size): if ufield not in backend_unique_id: ufields.append(ufield) From 19bd3b3576c4b09b4a3d3f3e1132115395fba567 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 4 Nov 2022 15:56:00 +0100 Subject: [PATCH 3/6] qparams to query_params --- aiida/tools/archive/imports.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index 3f57143f49..a26fef2cd9 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -129,7 +129,7 @@ def import_archive( type_check(test_run, bool) backend = backend or get_manager().get_profile_storage() type_check(backend, StorageBackend) - qparams = QueryParams(batch_size=batch_size, filter_size=filter_size) + query_params = QueryParams(batch_size=batch_size, filter_size=filter_size) if group and not group.is_stored: group.store() @@ -170,28 +170,28 @@ def import_archive( # Every addition/update is made in a single transaction, which is commited on 
exit with backend.transaction(): - user_ids_archive_backend = _import_users(backend_from, backend, qparams) - computer_ids_archive_backend = _import_computers(backend_from, backend, qparams) + user_ids_archive_backend = _import_users(backend_from, backend, query_params) + computer_ids_archive_backend = _import_computers(backend_from, backend, query_params) if include_authinfos: _import_authinfos( - backend_from, backend, qparams, user_ids_archive_backend, computer_ids_archive_backend + backend_from, backend, query_params, user_ids_archive_backend, computer_ids_archive_backend ) node_ids_archive_backend = _import_nodes( - backend_from, backend, qparams, user_ids_archive_backend, computer_ids_archive_backend, + backend_from, backend, query_params, user_ids_archive_backend, computer_ids_archive_backend, import_new_extras, merge_extras ) - _import_logs(backend_from, backend, qparams, node_ids_archive_backend) + _import_logs(backend_from, backend, query_params, node_ids_archive_backend) _import_comments( - backend_from, backend, qparams, user_ids_archive_backend, node_ids_archive_backend, merge_comments + backend_from, backend, query_params, user_ids_archive_backend, node_ids_archive_backend, merge_comments ) - _import_links(backend_from, backend, qparams, node_ids_archive_backend) + _import_links(backend_from, backend, query_params, node_ids_archive_backend) group_labels = _import_groups( - backend_from, backend, qparams, user_ids_archive_backend, node_ids_archive_backend + backend_from, backend, query_params, user_ids_archive_backend, node_ids_archive_backend ) import_group_id = None if create_group: - import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, qparams) - new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, qparams) + import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, query_params) + new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, query_params) if test_run: # exit before we write anything to the database or repository From 86b97ac1271930d5d542011642df96f531021137 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 4 Nov 2022 14:58:37 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- aiida/tools/archive/imports.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index a26fef2cd9..7b384cce2d 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -190,7 +190,9 @@ def import_archive( ) import_group_id = None if create_group: - import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, query_params) + import_group_id = _make_import_group( + group, group_labels, node_ids_archive_backend, backend, query_params + ) new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, query_params) if test_run: From 43c2ba1ba49f2729af6bdea5a336d77812df392a Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 4 Nov 2022 18:22:45 +0100 Subject: [PATCH 5/6] Update imports.py --- aiida/tools/archive/imports.py | 113 +++++++++++++++++---------------- 1 file changed, 59 insertions(+), 54 deletions(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index 7b384cce2d..b2754b603d 100644 --- 
a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -210,7 +210,7 @@ def import_archive( def _add_new_entities( etype: EntityTypes, total: int, unique_field: str, backend_unique_id: dict, backend_from: StorageBackend, - backend_to: StorageBackend, qparams: QueryParams, transform: Callable[[dict], dict] + backend_to: StorageBackend, query_params: QueryParams, transform: Callable[[dict], dict] ) -> None: """Add new entities to the output backend and update the mapping of unique field -> id.""" IMPORT_LOGGER.report(f'Adding {total} new {etype.value}(s)') @@ -218,12 +218,12 @@ def _add_new_entities( # collect the unique entities from the input backend to be added to the output backend ufields = [] query = QueryBuilder(backend=backend_from).append(entity_type_to_orm[etype], project=unique_field) - for (ufield,) in query.distinct().iterall(batch_size=qparams.batch_size): + for (ufield,) in query.distinct().iterall(batch_size=query_params.batch_size): if ufield not in backend_unique_id: ufields.append(ufield) with get_progress_reporter()(desc=f'Adding new {etype.value}(s)', total=total) as progress: - for nrows, ufields_batch in batch_iter(ufields, qparams.filter_size): + for nrows, ufields_batch in batch_iter(ufields, query_params.filter_size): rows = [ transform(row) for row in QueryBuilder(backend=backend_from).append( entity_type_to_orm[etype], @@ -234,21 +234,22 @@ def _add_new_entities( }, project=['**'], tag='entity' - ).dict(batch_size=qparams.batch_size) + ).dict(batch_size=query_params.batch_size) ] new_ids = backend_to.bulk_insert(etype, rows) backend_unique_id.update({row[unique_field]: pk for pk, row in zip(new_ids, rows)}) progress.update(nrows) -def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams) -> Dict[int, int]: +def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, + query_params: QueryParams) -> Dict[int, int]: """Import users from one backend to another. 
:returns: mapping of input backend id to output backend id """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_email = dict(qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=qparams.batch_size)) + input_id_email = dict(qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=query_params.batch_size)) # get matching emails from the backend output_email_id = {} @@ -260,7 +261,7 @@ def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, qpar 'email': { 'in': list(input_id_email.values()) } - }, project=['email', 'id']).all(batch_size=qparams.batch_size) + }, project=['email', 'id']).all(batch_size=query_params.batch_size) ) new_users = len(input_id_email) - len(output_email_id) @@ -272,21 +273,22 @@ def _import_users(backend_from: StorageBackend, backend_to: StorageBackend, qpar # add new users and update output_email_id with their email -> id mapping transform = lambda row: {k: v for k, v in row['entity'].items() if k != 'id'} _add_new_entities( - EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, qparams, transform + EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, query_params, transform ) # generate mapping of input backend id to output backend id return {int(i): output_email_id[email] for i, email in input_id_email.items()} -def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams) -> Dict[int, int]: +def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, + query_params: QueryParams) -> Dict[int, int]: """Import computers from one backend to another. :returns: mapping of input backend id to output backend id """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=query_params.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -298,7 +300,7 @@ def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) + }, project=['uuid', 'id']).all(batch_size=query_params.batch_size) ) new_computers = len(input_id_uuid) - len(backend_uuid_id) @@ -311,8 +313,9 @@ def _import_computers(backend_from: StorageBackend, backend_to: StorageBackend, # Labels should be unique, so we create new labels on clashes labels = { - label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Computer, project='label' - ).iterall(batch_size=qparams.batch_size) + label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Computer, project='label').iterall( + batch_size=query_params.batch_size + ) } relabelled = 0 @@ -336,7 +339,8 @@ def transform(row: dict) -> dict: return data _add_new_entities( - EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform + EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, query_params, + transform ) if relabelled: @@ -347,7 +351,7 @@ def transform(row: dict) -> dict: def _import_authinfos( - backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, + backend_from: StorageBackend, backend_to: StorageBackend, query_params: QueryParams, 
user_ids_archive_backend: Dict[int, int], computer_ids_archive_backend: Dict[int, int] ) -> None: """Import logs from one backend to another. @@ -360,7 +364,7 @@ def _import_authinfos( qbuilder.append( orm.AuthInfo, project=['id', 'aiidauser_id', 'dbcomputer_id'], - ).all(batch_size=qparams.batch_size) + ).all(batch_size=query_params.batch_size) ) # translate user_id / computer_id, from -> to @@ -387,7 +391,7 @@ def _import_authinfos( project=['id', 'aiidauser_id', 'dbcomputer_id'] ) backend_id_user_comp = [(user_id, comp_id) - for _, user_id, comp_id in qbuilder.all(batch_size=qparams.batch_size) + for _, user_id, comp_id in qbuilder.all(batch_size=query_params.batch_size) if (user_id, comp_id) in to_user_id_comp_id] new_authinfos = len(input_id_user_comp) - len(backend_id_user_comp) @@ -420,7 +424,7 @@ def transform(row: dict) -> dict: with get_progress_reporter()( desc=f'Adding new {EntityTypes.AUTHINFO.value}(s)', total=qbuilder.count() ) as progress: - for nrows, rows in batch_iter(iterator, qparams.batch_size, transform): + for nrows, rows in batch_iter(iterator, query_params.batch_size, transform): backend_to.bulk_insert(EntityTypes.AUTHINFO, rows) progress.update(nrows) @@ -428,7 +432,7 @@ def transform(row: dict) -> dict: def _import_nodes( backend_from: StorageBackend, backend_to: StorageBackend, - qparams: QueryParams, + query_params: QueryParams, user_ids_archive_backend: Dict[int, int], computer_ids_archive_backend: Dict[int, int], import_new_extras: bool, @@ -441,7 +445,7 @@ def _import_nodes( IMPORT_LOGGER.report('Collecting Node(s) ...') # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=query_params.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -453,19 +457,19 @@ def _import_nodes( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) + }, project=['uuid', 'id']).all(batch_size=query_params.batch_size) ) new_nodes = len(input_id_uuid) - len(backend_uuid_id) if backend_uuid_id: - _merge_node_extras(backend_from, backend_to, qparams, backend_uuid_id, merge_extras) + _merge_node_extras(backend_from, backend_to, query_params, backend_uuid_id, merge_extras) if new_nodes: # add new nodes and update backend_uuid_id with their uuid -> id mapping transform = NodeTransform(user_ids_archive_backend, computer_ids_archive_backend, import_new_extras) _add_new_entities( - EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform + EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, query_params, transform ) # generate mapping of input backend id to output backend id @@ -512,7 +516,7 @@ def __call__(self, row: dict) -> dict: def _import_logs( backend_from: StorageBackend, backend_to: StorageBackend, - qparams: QueryParams, + query_params: QueryParams, node_ids_archive_backend: Dict[int, int], ) -> Dict[int, int]: """Import logs from one backend to another. 
@@ -521,7 +525,7 @@ def _import_logs( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=query_params.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -533,7 +537,7 @@ def _import_logs( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) + }, project=['uuid', 'id']).all(batch_size=query_params.batch_size) ) new_logs = len(input_id_uuid) - len(backend_uuid_id) @@ -553,7 +557,7 @@ def transform(row: dict) -> dict: return data _add_new_entities( - EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform + EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, query_params, transform ) # generate mapping of input backend id to output backend id @@ -561,8 +565,8 @@ def transform(row: dict) -> dict: def _merge_node_extras( - backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, backend_uuid_id: Dict[str, int], - mode: MergeExtrasType + backend_from: StorageBackend, backend_to: StorageBackend, query_params: QueryParams, + backend_uuid_id: Dict[str, int], mode: MergeExtrasType ) -> None: """Merge extras from the input backend with the ones in the output backend. @@ -592,7 +596,7 @@ def _merge_node_extras( transform = lambda row: {'id': backend_uuid_id[row[0]], 'extras': row[1]} with get_progress_reporter()(desc='Replacing extras', total=input_extras.count()) as progress: for nrows, rows in batch_iter( - input_extras.iterall(batch_size=qparams.batch_size), qparams.batch_size, transform + input_extras.iterall(batch_size=query_params.batch_size), query_params.batch_size, transform ): backend_to.bulk_update(EntityTypes.NODE, rows) progress.update(nrows) @@ -675,9 +679,9 @@ def _transform(data: Tuple[Tuple[str, dict], Tuple[str, dict]]) -> dict: with get_progress_reporter()(desc='Merging extras', total=input_extras.count()) as progress: for nrows, rows in batch_iter( zip( - input_extras.iterall(batch_size=qparams.batch_size), - backend_extras.iterall(batch_size=qparams.batch_size) - ), qparams.batch_size, _transform + input_extras.iterall(batch_size=query_params.batch_size), + backend_extras.iterall(batch_size=query_params.batch_size) + ), query_params.batch_size, _transform ): backend_to.bulk_update(EntityTypes.NODE, rows) progress.update(nrows) @@ -712,7 +716,7 @@ def __call__(self, row: dict) -> dict: def _import_comments( backend_from: StorageBackend, backend: StorageBackend, - qparams: QueryParams, + query_params: QueryParams, user_ids_archive_backend: Dict[int, int], node_ids_archive_backend: Dict[int, int], merge_comments: MergeCommentsType, @@ -723,7 +727,7 @@ def _import_comments( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=query_params.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -735,7 +739,7 @@ def _import_comments( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) + }, project=['uuid', 'id']).all(batch_size=query_params.batch_size) ) new_comments = 
len(input_id_uuid) - len(backend_uuid_id) @@ -759,7 +763,7 @@ def _transform(row): with get_progress_reporter()(desc='Overwriting comments', total=archive_comments.count()) as progress: for nrows, rows in batch_iter( - archive_comments.iterall(batch_size=qparams.batch_size), qparams.batch_size, _transform + archive_comments.iterall(batch_size=query_params.batch_size), query_params.batch_size, _transform ): backend.bulk_update(EntityTypes.COMMENT, rows) progress.update(nrows) @@ -777,7 +781,7 @@ def _transform(row): with get_progress_reporter()(desc='Updating comments', total=archive_comments.count()) as progress: for nrows, rows in batch_iter( - archive_comments.iterall(batch_size=qparams.batch_size), qparams.batch_size, _transform + archive_comments.iterall(batch_size=query_params.batch_size), query_params.batch_size, _transform ): progress.update(nrows) @@ -786,7 +790,7 @@ def _transform(row): if new_comments: # add new comments and update backend_uuid_id with their uuid -> id mapping _add_new_entities( - EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, qparams, + EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, query_params, CommentTransform(user_ids_archive_backend, node_ids_archive_backend) ) @@ -797,7 +801,7 @@ def _transform(row): def _import_links( backend_from: StorageBackend, backend_to: StorageBackend, - qparams: QueryParams, + query_params: QueryParams, node_ids_archive_backend: Dict[int, int], ) -> None: """Import links from one backend to another.""" @@ -854,7 +858,7 @@ def _import_links( tuple(link) for link in orm.QueryBuilder(backend=backend_to). append(entity_type='link', filters={ 'type': link_type.value - }, project=['input_id', 'output_id', 'label']).iterall(batch_size=qparams.batch_size) + }, project=['input_id', 'output_id', 'label']).iterall(batch_size=query_params.batch_size) } # create additional validators # note, we only populate them when required, to reduce memory usage @@ -867,7 +871,7 @@ def _import_links( insert_rows = [] with get_progress_reporter()(desc=f'Processing {link_type.value!r} Link(s)', total=total) as progress: for in_id, in_type, out_id, out_type, link_id, link_label in archive_query.iterall( - batch_size=qparams.batch_size + batch_size=query_params.batch_size ): progress.update() @@ -921,7 +925,7 @@ def _import_links( existing_out_id_label.add((out_id, link_label)) # flush new rows, once batch size is reached - if (new_count % qparams.batch_size) == 0: + if (new_count % query_params.batch_size) == 0: backend_to.bulk_insert(EntityTypes.LINK, insert_rows) insert_rows = [] @@ -969,7 +973,7 @@ def __call__(self, row: dict) -> dict: def _import_groups( - backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams, + backend_from: StorageBackend, backend_to: StorageBackend, query_params: QueryParams, user_ids_archive_backend: Dict[int, int], node_ids_archive_backend: Dict[int, int] ) -> Set[str]: """Import groups from the input backend, and add group -> node records. 
@@ -978,7 +982,7 @@ def _import_groups( """ # get the records from the input backend qbuilder = QueryBuilder(backend=backend_from) - input_id_uuid = dict(qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=qparams.batch_size)) + input_id_uuid = dict(qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=query_params.batch_size)) # get matching uuids from the backend backend_uuid_id = {} @@ -990,13 +994,13 @@ def _import_groups( 'uuid': { 'in': list(input_id_uuid.values()) } - }, project=['uuid', 'id']).all(batch_size=qparams.batch_size) + }, project=['uuid', 'id']).all(batch_size=query_params.batch_size) ) # get all labels labels = { label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Group, project='label' - ).iterall(batch_size=qparams.batch_size) + ).iterall(batch_size=query_params.batch_size) } new_groups = len(input_id_uuid) - len(backend_uuid_id) @@ -1011,7 +1015,7 @@ def _import_groups( transform = GroupTransform(user_ids_archive_backend, labels) _add_new_entities( - EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, qparams, transform + EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, query_params, transform ) if transform.relabelled: @@ -1040,7 +1044,7 @@ def group_node_transform(row): with get_progress_reporter()(desc=f'Adding new {EntityTypes.GROUP_NODE.value}(s)', total=total) as progress: for nrows, rows in batch_iter( - iterator.iterall(batch_size=qparams.batch_size), qparams.batch_size, group_node_transform + iterator.iterall(batch_size=query_params.batch_size), query_params.batch_size, group_node_transform ): backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) progress.update(nrows) @@ -1050,7 +1054,7 @@ def group_node_transform(row): def _make_import_group( group: Optional[orm.Group], labels: Set[str], node_ids_archive_backend: Dict[int, int], backend_to: StorageBackend, - qparams: QueryParams + query_params: QueryParams ) -> Optional[int]: """Make an import group containing all imported nodes. 
@@ -1094,7 +1098,8 @@ def _make_import_group( group_node_ids = { pk for pk, in orm.QueryBuilder(backend=backend_to).append(orm.Group, filters={ 'id': group_id - }, tag='group').append(orm.Node, with_group='group', project='id').iterall(batch_size=qparams.batch_size) + }, tag='group').append(orm.Node, with_group='group', project='id' + ).iterall(batch_size=query_params.batch_size) } # Add all the nodes to the Group @@ -1105,7 +1110,7 @@ def _make_import_group( 'dbgroup_id': group_id, 'dbnode_id': node_id } for node_id in node_ids_archive_backend.values() if node_id not in group_node_ids) - for nrows, rows in batch_iter(iterator, qparams.batch_size): + for nrows, rows in batch_iter(iterator, query_params.batch_size): backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) progress.update(nrows) @@ -1113,13 +1118,13 @@ def _make_import_group( def _get_new_object_keys( - key_format: str, backend_from: StorageBackend, backend_to: StorageBackend, qparams: QueryParams + key_format: str, backend_from: StorageBackend, backend_to: StorageBackend, query_params: QueryParams ) -> Set[str]: """Return the object keys that need to be added to the backend.""" archive_hashkeys: Set[str] = set() query = QueryBuilder(backend=backend_from).append(orm.Node, project='repository_metadata') with get_progress_reporter()(desc='Collecting archive Node file keys', total=query.count()) as progress: - for repository_metadata, in query.iterall(batch_size=qparams.batch_size): + for repository_metadata, in query.iterall(batch_size=query_params.batch_size): archive_hashkeys.update(key for key in Repository.flatten(repository_metadata).values() if key is not None) progress.update() From c0cd3c6d0f4e87af34d2518122564ea7c7ac1c19 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 4 Nov 2022 18:29:03 +0100 Subject: [PATCH 6/6] Update imports.py --- aiida/tools/archive/imports.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py index b2754b603d..8d4b628d99 100644 --- a/aiida/tools/archive/imports.py +++ b/aiida/tools/archive/imports.py @@ -61,7 +61,7 @@ class QueryParams: batch_size: int """Batch size for streaming database rows.""" filter_size: int - """Maximum size of parameters allowed in a single query filter.""" + """Maximum number of parameters allowed in a single query filter.""" def import_archive( @@ -223,6 +223,8 @@ def _add_new_entities( ufields.append(ufield) with get_progress_reporter()(desc=f'Adding new {etype.value}(s)', total=total) as progress: + # batch the filtering of rows by filter size, to limit the number of query variables used in any one query, + # since certain backends have a limit on the number of variables in a query (such as SQLITE_MAX_VARIABLE_NUMBER) for nrows, ufields_batch in batch_iter(ufields, query_params.filter_size): rows = [ transform(row) for row in QueryBuilder(backend=backend_from).append(