From 98e902bcc5d38363b5109c38cfae3e6073a3a23c Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 28 Nov 2022 14:34:05 +0100 Subject: [PATCH] `PsqlDosBackend`: Use transaction whenever mutating session state Storing a node while iterating over the result of `QueryBuilder.iterall` would raise `sqlalchemy.exc.InvalidRequestError` with the message: Can't operate on closed transaction inside context manager. The problem was that the `Node` implementation for the `PsqlDosBackend`, the `SqlaNode` class, would not consistently open a transaction, using the `PsqlDosBackend.transaction` method, whenever it mutated the state of the session, and would then straight up commit to the current session. For example, when storing a new node, the `store` method would simply call `save`. Through the `ModelWrapper`, this would call commit on the session, but that was the session being used for the iteration. The same problem was present for the `SqlaGroup` implementation that had a number of places where session state was mutated without opening a transaction first. The problem is therefore fixed by consistently opening a transaction before making changes to the session. The `transaction` implementation is slightly changed as any `SqlaIntegrityError` raised during the context is now converted into an `aiida.common.exceptions.IntegrityError` to make it backend-independent. 
--- aiida/storage/psql_dos/backend.py | 43 +++++++++---------- aiida/storage/psql_dos/orm/groups.py | 40 +++++++++--------- aiida/storage/psql_dos/orm/nodes.py | 53 +++++++++--------------- aiida/tools/graph/deletions.py | 3 +- tests/orm/implementation/test_backend.py | 7 +--- tests/orm/nodes/test_node.py | 3 +- tests/orm/test_querybuilder.py | 31 +++++++++++++- tests/storage/psql_dos/test_nodes.py | 2 +- tests/storage/psql_dos/test_session.py | 6 ++- tests/test_nodes.py | 4 +- 10 files changed, 101 insertions(+), 91 deletions(-) diff --git a/aiida/storage/psql_dos/backend.py b/aiida/storage/psql_dos/backend.py index 42dd076308..1500894eb7 100644 --- a/aiida/storage/psql_dos/backend.py +++ b/aiida/storage/psql_dos/backend.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union from disk_objectstore import Container +from sqlalchemy.exc import IntegrityError as SqlaIntegrityError from sqlalchemy.orm import Session, scoped_session, sessionmaker from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError @@ -232,20 +233,23 @@ def users(self): @contextmanager def transaction(self) -> Iterator[Session]: - """Open a transaction to be used as a context manager. + """Open a transaction and yield the current session. If there is an exception within the context then the changes will be rolled back and the state will be as before - entering. Transactions can be nested. + entering, otherwise the changes will be committed and the transaction closed. Transactions can be nested. 
""" session = self.get_session() - if session.in_transaction(): - with session.begin_nested(): - yield session - session.commit() - else: - with session.begin(): + + try: + if session.in_transaction(): with session.begin_nested(): yield session + else: + with session.begin(): + with session.begin_nested(): + yield session + except SqlaIntegrityError as exception: + raise IntegrityError(str(exception)) from exception @property def in_transaction(self) -> bool: @@ -323,19 +327,16 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # from aiida.storage.psql_dos.models.group import DbGroupNode from aiida.storage.psql_dos.models.node import DbLink, DbNode - if not self.in_transaction: - raise AssertionError('Cannot delete nodes and links outside a transaction') - - session = self.get_session() - # Delete the membership of these nodes to groups. - session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) - ).delete(synchronize_session='fetch') - # Delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Delete the links pointing to the nodes marked for deletion. - session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Delete the actual nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + with self.transaction() as session: + # Delete the membership of these nodes to groups. + session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') + # Delete the links coming out of the nodes marked for deletion. + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the links pointing to the nodes marked for deletion. 
+ session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the actual nodes + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') def get_backend_entity(self, model: base.Base) -> BackendEntity: """ diff --git a/aiida/storage/psql_dos/orm/groups.py b/aiida/storage/psql_dos/orm/groups.py index 8c3e2b1940..319fc22a82 100644 --- a/aiida/storage/psql_dos/orm/groups.py +++ b/aiida/storage/psql_dos/orm/groups.py @@ -15,9 +15,10 @@ from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode -from . import entities, users, utils +from . import entities, users from .extras_mixin import ExtrasMixin from .nodes import SqlaNode +from .utils import ModelWrapper, disable_expire_on_commit _LOGGER = logging.getLogger(__name__) @@ -46,7 +47,7 @@ def __init__(self, backend, label, user, description='', type_string=''): super().__init__(backend) dbgroup = self.MODEL_CLASS(label=label, description=description, user=user.bare_model, type_string=type_string) - self._model = utils.ModelWrapper(dbgroup, backend) + self._model = ModelWrapper(dbgroup, backend) @property def label(self): @@ -115,8 +116,9 @@ def is_stored(self): return self.pk is not None def store(self): - self.model.save() - return self + with self.backend.transaction(): + self.model.save() + return self def count(self): """Return the number of entities in this group. 
@@ -128,10 +130,9 @@ def count(self): def clear(self): """Remove all the nodes from this group.""" - session = self.backend.get_session() - # Note we have to call `bare_model` to circumvent flushing data to the database - self.bare_model.dbnodes = [] - session.commit() + with self.backend.transaction(): + # Note we have to call `bare_model` to circumvent flushing data to the database + self.bare_model.dbnodes = [] @property def nodes(self): @@ -192,7 +193,10 @@ def check_node(given_node): if not given_node.is_stored: raise ValueError('At least one of the provided nodes is unstored, stopping...') - with utils.disable_expire_on_commit(self.backend.get_session()) as session: + session = self.backend.get_session() + + with disable_expire_on_commit(session), self.backend.transaction() as session: + assert session.expire_on_commit is False if not skip_orm: # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self.model.dbnodes @@ -219,9 +223,6 @@ def check_node(given_node): ins = insert(table).values(ins_dict) session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id'])) - # Commit everything as up till now we've just flushed - session.commit() - def remove_nodes(self, nodes, **kwargs): """Remove a node or a set of nodes from the group. 
@@ -249,7 +250,10 @@ def check_node(node): list_nodes = [] - with utils.disable_expire_on_commit(self.backend.get_session()) as session: + session = self.backend.get_session() + + with disable_expire_on_commit(session), self.backend.transaction() as session: + assert session.expire_on_commit is False if not skip_orm: for node in nodes: check_node(node) @@ -268,8 +272,6 @@ def check_node(node): statement = table.delete().where(clause) session.execute(statement) - session.commit() - class SqlaGroupCollection(BackendGroupCollection): """The SLQA collection of groups""" @@ -277,8 +279,6 @@ class SqlaGroupCollection(BackendGroupCollection): ENTITY_CLASS = SqlaGroup def delete(self, id): # pylint: disable=redefined-builtin - session = self.backend.get_session() - - row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id) - session.delete(row) - session.commit() + with self.backend.transaction() as session: + row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id) + session.delete(row) diff --git a/aiida/storage/psql_dos/orm/nodes.py b/aiida/storage/psql_dos/orm/nodes.py index 6822e05d50..a7c3760000 100644 --- a/aiida/storage/psql_dos/orm/nodes.py +++ b/aiida/storage/psql_dos/orm/nodes.py @@ -183,8 +183,6 @@ def user(self, user): self.model.user = user.bare_model def add_incoming(self, source, link_type, link_label): - session = self.backend.get_session() - type_check(source, self.__class__) if not self.is_stored: @@ -194,43 +192,33 @@ def add_incoming(self, source, link_type, link_label): raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it') self._add_link(source, link_type, link_label) - session.commit() def _add_link(self, source, link_type, link_label): """Add a single link""" - session = self.backend.get_session() - - try: - with session.begin_nested(): + with self.backend.transaction() as session: + try: link = self.LINK_CLASS(input_id=source.pk, output_id=self.pk, label=link_label, type=link_type.value) session.add(link) - 
except SQLAlchemyError as exception: - raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception + except SQLAlchemyError as exception: + raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception def clean_values(self): self.model.attributes = clean_value(self.model.attributes) self.model.extras = clean_value(self.model.extras) - def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ - session = self.backend.get_session() - - if clean: - self.clean_values() + def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ,unused-argument + with self.backend.transaction(): - session.add(self.model) + if clean: + self.clean_values() - if links: - for link_triple in links: - self._add_link(*link_triple) + self.model.save() - if with_transaction: - try: - session.commit() - except SQLAlchemyError: - session.rollback() - raise + if links: + for link_triple in links: + self._add_link(*link_triple) - return self + return self @property def attributes(self): @@ -313,7 +301,6 @@ class SqlaNodeCollection(BackendNodeCollection): def get(self, pk): session = self.backend.get_session() - try: return self.ENTITY_CLASS.from_dbmodel( session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one(), self.backend @@ -322,11 +309,9 @@ def get(self, pk): raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound def delete(self, pk): - session = self.backend.get_session() - - try: - row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one() - session.delete(row) - session.commit() - except NoResultFound: - raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound + with self.backend.transaction() as session: + try: + row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one() + session.delete(row) + except NoResultFound: + raise exceptions.NotExistent(f"Node with 
pk '{pk}' not found") from NoResultFound diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index 48011a2550..bc8465d533 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -104,8 +104,7 @@ def _missing_callback(_pks: Iterable[int]): return (pks_set_to_delete, True) DELETE_LOGGER.report('Starting node deletion...') - with backend.transaction(): - backend.delete_nodes_and_connections(pks_set_to_delete) + backend.delete_nodes_and_connections(pks_set_to_delete) DELETE_LOGGER.report('Deletion of nodes completed.') return (pks_set_to_delete, True) diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index ede98f9126..17b74465c4 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -162,12 +162,7 @@ def test_delete_nodes_and_connections(self): assert len(calc_node.base.links.get_outgoing().all()) == 1 assert len(group.nodes) == 1 - # cannot call outside a transaction - with pytest.raises(AssertionError): - self.backend.delete_nodes_and_connections([node_pk]) - - with self.backend.transaction(): - self.backend.delete_nodes_and_connections([node_pk]) + self.backend.delete_nodes_and_connections([node_pk]) # checks after deletion with pytest.raises(exceptions.NotExistent): diff --git a/tests/orm/nodes/test_node.py b/tests/orm/nodes/test_node.py index f9297ac3ba..6db9a77514 100644 --- a/tests/orm/nodes/test_node.py +++ b/tests/orm/nodes/test_node.py @@ -860,8 +860,7 @@ def test_delete_through_backend(self): assert len(Log.collection.get_logs_for(data_two)) == 1 assert Log.collection.get_logs_for(data_two)[0].pk == log_two.pk - with backend.transaction(): - backend.delete_nodes_and_connections([data_two.pk]) + backend.delete_nodes_and_connections([data_two.pk]) assert len(Log.collection.get_logs_for(data_one)) == 1 assert Log.collection.get_logs_for(data_one)[0].pk == log_one.pk diff --git a/tests/orm/test_querybuilder.py 
b/tests/orm/test_querybuilder.py index ea741e0657..9893660eef 100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -1476,7 +1476,10 @@ def test_len_results(self): assert len(qb.all()) == qb.count() def test_iterall_with_mutation(self): - """Test that nodes can be mutated while being iterated using ``QueryBuilder.iterall``.""" + """Test that nodes can be mutated while being iterated using ``QueryBuilder.iterall``. + + This is a regression test for https://github.com/aiidateam/aiida-core/issues/5672 . + """ count = 10 pks = [] @@ -1491,6 +1494,32 @@ def test_iterall_with_mutation(self): for pk in pks: assert orm.load_node(pk).get_extra('key') == 'value' + @pytest.mark.usefixtures('aiida_profile_clean') + def test_iterall_with_store(self): + """Test that nodes can be stored while being iterated using ``QueryBuilder.iterall``. + + This is a regression test for https://github.com/aiidateam/aiida-core/issues/5802 . + """ + count = 10 + pks = [] + pks_clone = [] + + for index in range(count): + node = orm.Int(index).store() + pks.append(node.pk) + + # Ensure that batch size is smaller than the total rows yielded + for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=2): + clone = copy.deepcopy(node) + clone.store() + pks_clone.append((clone.value, clone.pk)) + group = orm.Group(label=str(node.uuid)).store() + group.add_nodes([node]) + + # Need to sort the cloned pks based on the value, because the order of ``iterall`` is not guaranteed + for pk, pk_clone in zip(pks, [e[1] for e in sorted(pks_clone)]): + assert orm.load_node(pk) == orm.load_node(pk_clone) + class TestManager: diff --git a/tests/storage/psql_dos/test_nodes.py b/tests/storage/psql_dos/test_nodes.py index a3953f3fe7..223aa54f8e 100644 --- a/tests/storage/psql_dos/test_nodes.py +++ b/tests/storage/psql_dos/test_nodes.py @@ -81,7 +81,7 @@ def test_multiple_node_creation(self): from aiida.storage.psql_dos.models.node import DbNode # Get the automatic user - dbuser = 
self.backend.users.create('user@aiida.net').store().bare_model + dbuser = self.backend.users.create(get_new_uuid()).store().bare_model # Create a new node but don't add it to the session node_uuid = get_new_uuid() DbNode(user=dbuser, uuid=node_uuid, node_type=None) diff --git a/tests/storage/psql_dos/test_session.py b/tests/storage/psql_dos/test_session.py index 24d0c3fe9e..b61b0a5553 100644 --- a/tests/storage/psql_dos/test_session.py +++ b/tests/storage/psql_dos/test_session.py @@ -151,13 +151,15 @@ def test_node_access_with_sessions(self): custom_session = session() try: - user = self.backend.users.create(email=uuid.uuid4().hex).store() - node = self.backend.nodes.create(node_type='', user=user).store() + with self.backend.transaction(): + user = self.backend.users.create(email=uuid.uuid4().hex).store() + node = self.backend.nodes.create(node_type='', user=user).store() master_session = node.model.session # pylint: disable=protected-access assert master_session is not custom_session # Manually load the DbNode in a different session dbnode_reloaded = custom_session.get(sa.models.node.DbNode, node.id) + assert dbnode_reloaded is not None # Now, go through one by one changing the possible attributes (of the model) # and check that they're updated when the user reads them from the aiida node diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 757d163526..cfc0e7db23 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -272,14 +272,14 @@ def test_uuid_uniquess(self): """ A uniqueness constraint on the UUID column of the Node model should prevent multiple nodes with identical UUID """ - from sqlalchemy.exc import IntegrityError as SqlaIntegrityError + from aiida.common.exceptions import IntegrityError a = orm.Data() b = orm.Data() b.backend_entity.bare_model.uuid = a.uuid a.store() - with pytest.raises(SqlaIntegrityError): + with pytest.raises(IntegrityError): b.store() def test_attribute_mutability(self):