Skip to content

Commit

Permalink
PsqlDosBackend: Use transaction whenever mutating session state
Browse files Browse the repository at this point in the history
Storing a node while iterating over the result of `QueryBuilder.iterall`
would raise `sqlalchemy.exc.InvalidRequestError` with the message:

    Can't operate on closed transaction inside context manager.

The problem was that the `Node` implementation for the `PsqlDosBackend`,
the `SqlaNode` class, would not consistently open a transaction, using
the `PsqlDosBackend.transaction` method, whenever it mutated the state
of the session, and would then straight up commit to the current session.
For example, when storing a new node, the `store` method would simply
call save. Through the `ModelWrapper`, this would call commit on the
session, but that was the session being used for the iteration.

The same problem was present for the `SqlaGroup` implementation that had
a number of places where session state was mutated without opening a
transaction first.

The problem is therefore fixed by consistently opening a transaction
before making changes to the session. The `transaction` implementation
is slightly changed as any `SqlaIntegrityError` raised during the context
is now converted into an `aiida.common.exceptions.IntegrityError` to make
it backend independent.
  • Loading branch information
sphuber committed Dec 12, 2022
1 parent 688ace5 commit 98e902b
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 91 deletions.
43 changes: 22 additions & 21 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING, Iterator, List, Optional, Sequence, Set, Union

from disk_objectstore import Container
from sqlalchemy.exc import IntegrityError as SqlaIntegrityError
from sqlalchemy.orm import Session, scoped_session, sessionmaker

from aiida.common.exceptions import ClosedStorage, ConfigurationError, IntegrityError
Expand Down Expand Up @@ -232,20 +233,23 @@ def users(self):

@contextmanager
def transaction(self) -> Iterator[Session]:
"""Open a transaction to be used as a context manager.
"""Open a transaction and yield the current session.
If there is an exception within the context then the changes will be rolled back and the state will be as before
entering. Transactions can be nested.
entering, otherwise the changes will be committed and the transaction closed. Transactions can be nested.
"""
session = self.get_session()
if session.in_transaction():
with session.begin_nested():
yield session
session.commit()
else:
with session.begin():

try:
if session.in_transaction():
with session.begin_nested():
yield session
else:
with session.begin():
with session.begin_nested():
yield session
except SqlaIntegrityError as exception:
raise IntegrityError(str(exception)) from exception

@property
def in_transaction(self) -> bool:
Expand Down Expand Up @@ -323,19 +327,16 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: #
from aiida.storage.psql_dos.models.group import DbGroupNode
from aiida.storage.psql_dos.models.node import DbLink, DbNode

if not self.in_transaction:
raise AssertionError('Cannot delete nodes and links outside a transaction')

session = self.get_session()
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
with self.transaction() as session:
# Delete the membership of these nodes to groups.
session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
).delete(synchronize_session='fetch')
# Delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Delete the actual nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

def get_backend_entity(self, model: base.Base) -> BackendEntity:
"""
Expand Down
40 changes: 20 additions & 20 deletions aiida/storage/psql_dos/orm/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection
from aiida.storage.psql_dos.models.group import DbGroup, DbGroupNode

from . import entities, users, utils
from . import entities, users
from .extras_mixin import ExtrasMixin
from .nodes import SqlaNode
from .utils import ModelWrapper, disable_expire_on_commit

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(self, backend, label, user, description='', type_string=''):
super().__init__(backend)

dbgroup = self.MODEL_CLASS(label=label, description=description, user=user.bare_model, type_string=type_string)
self._model = utils.ModelWrapper(dbgroup, backend)
self._model = ModelWrapper(dbgroup, backend)

@property
def label(self):
Expand Down Expand Up @@ -115,8 +116,9 @@ def is_stored(self):
return self.pk is not None

def store(self):
self.model.save()
return self
with self.backend.transaction():
self.model.save()
return self

def count(self):
"""Return the number of entities in this group.
Expand All @@ -128,10 +130,9 @@ def count(self):

def clear(self):
"""Remove all the nodes from this group."""
session = self.backend.get_session()
# Note we have to call `bare_model` to circumvent flushing data to the database
self.bare_model.dbnodes = []
session.commit()
with self.backend.transaction():
# Note we have to call `bare_model` to circumvent flushing data to the database
self.bare_model.dbnodes = []

@property
def nodes(self):
Expand Down Expand Up @@ -192,7 +193,10 @@ def check_node(given_node):
if not given_node.is_stored:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
session = self.backend.get_session()

with disable_expire_on_commit(session), self.backend.transaction() as session:
assert session.expire_on_commit is False
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes
Expand All @@ -219,9 +223,6 @@ def check_node(given_node):
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
session.commit()

def remove_nodes(self, nodes, **kwargs):
"""Remove a node or a set of nodes from the group.
Expand Down Expand Up @@ -249,7 +250,10 @@ def check_node(node):

list_nodes = []

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
session = self.backend.get_session()

with disable_expire_on_commit(session), self.backend.transaction() as session:
assert session.expire_on_commit is False
if not skip_orm:
for node in nodes:
check_node(node)
Expand All @@ -268,17 +272,13 @@ def check_node(node):
statement = table.delete().where(clause)
session.execute(statement)

session.commit()


class SqlaGroupCollection(BackendGroupCollection):
"""The SLQA collection of groups"""

ENTITY_CLASS = SqlaGroup

def delete(self, id): # pylint: disable=redefined-builtin
session = self.backend.get_session()

row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id)
session.delete(row)
session.commit()
with self.backend.transaction() as session:
row = session.get(self.ENTITY_CLASS.MODEL_CLASS, id)
session.delete(row)
53 changes: 19 additions & 34 deletions aiida/storage/psql_dos/orm/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def user(self, user):
self.model.user = user.bare_model

def add_incoming(self, source, link_type, link_label):
session = self.backend.get_session()

type_check(source, self.__class__)

if not self.is_stored:
Expand All @@ -194,43 +192,33 @@ def add_incoming(self, source, link_type, link_label):
raise exceptions.ModificationNotAllowed('source node has to be stored when adding a link from it')

self._add_link(source, link_type, link_label)
session.commit()

def _add_link(self, source, link_type, link_label):
"""Add a single link"""
session = self.backend.get_session()

try:
with session.begin_nested():
with self.backend.transaction() as session:
try:
link = self.LINK_CLASS(input_id=source.pk, output_id=self.pk, label=link_label, type=link_type.value)
session.add(link)
except SQLAlchemyError as exception:
raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception
except SQLAlchemyError as exception:
raise exceptions.UniquenessError(f'failed to create the link: {exception}') from exception

def clean_values(self):
self.model.attributes = clean_value(self.model.attributes)
self.model.extras = clean_value(self.model.extras)

def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ
session = self.backend.get_session()

if clean:
self.clean_values()
def store(self, links=None, with_transaction=True, clean=True): # pylint: disable=arguments-differ,unused-argument
with self.backend.transaction():

session.add(self.model)
if clean:
self.clean_values()

if links:
for link_triple in links:
self._add_link(*link_triple)
self.model.save()

if with_transaction:
try:
session.commit()
except SQLAlchemyError:
session.rollback()
raise
if links:
for link_triple in links:
self._add_link(*link_triple)

return self
return self

@property
def attributes(self):
Expand Down Expand Up @@ -313,7 +301,6 @@ class SqlaNodeCollection(BackendNodeCollection):

def get(self, pk):
session = self.backend.get_session()

try:
return self.ENTITY_CLASS.from_dbmodel(
session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one(), self.backend
Expand All @@ -322,11 +309,9 @@ def get(self, pk):
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound

def delete(self, pk):
session = self.backend.get_session()

try:
row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one()
session.delete(row)
session.commit()
except NoResultFound:
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound
with self.backend.transaction() as session:
try:
row = session.query(self.ENTITY_CLASS.MODEL_CLASS).filter_by(id=pk).one()
session.delete(row)
except NoResultFound:
raise exceptions.NotExistent(f"Node with pk '{pk}' not found") from NoResultFound
3 changes: 1 addition & 2 deletions aiida/tools/graph/deletions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def _missing_callback(_pks: Iterable[int]):
return (pks_set_to_delete, True)

DELETE_LOGGER.report('Starting node deletion...')
with backend.transaction():
backend.delete_nodes_and_connections(pks_set_to_delete)
backend.delete_nodes_and_connections(pks_set_to_delete)
DELETE_LOGGER.report('Deletion of nodes completed.')

return (pks_set_to_delete, True)
Expand Down
7 changes: 1 addition & 6 deletions tests/orm/implementation/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,7 @@ def test_delete_nodes_and_connections(self):
assert len(calc_node.base.links.get_outgoing().all()) == 1
assert len(group.nodes) == 1

# cannot call outside a transaction
with pytest.raises(AssertionError):
self.backend.delete_nodes_and_connections([node_pk])

with self.backend.transaction():
self.backend.delete_nodes_and_connections([node_pk])
self.backend.delete_nodes_and_connections([node_pk])

# checks after deletion
with pytest.raises(exceptions.NotExistent):
Expand Down
3 changes: 1 addition & 2 deletions tests/orm/nodes/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,7 @@ def test_delete_through_backend(self):
assert len(Log.collection.get_logs_for(data_two)) == 1
assert Log.collection.get_logs_for(data_two)[0].pk == log_two.pk

with backend.transaction():
backend.delete_nodes_and_connections([data_two.pk])
backend.delete_nodes_and_connections([data_two.pk])

assert len(Log.collection.get_logs_for(data_one)) == 1
assert Log.collection.get_logs_for(data_one)[0].pk == log_one.pk
Expand Down
31 changes: 30 additions & 1 deletion tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1476,7 +1476,10 @@ def test_len_results(self):
assert len(qb.all()) == qb.count()

def test_iterall_with_mutation(self):
"""Test that nodes can be mutated while being iterated using ``QueryBuilder.iterall``."""
"""Test that nodes can be mutated while being iterated using ``QueryBuilder.iterall``.
This is a regression test for https://github.com/aiidateam/aiida-core/issues/5672 .
"""
count = 10
pks = []

Expand All @@ -1491,6 +1494,32 @@ def test_iterall_with_mutation(self):
for pk in pks:
assert orm.load_node(pk).get_extra('key') == 'value'

@pytest.mark.usefixtures('aiida_profile_clean')
def test_iterall_with_store(self):
"""Test that nodes can be stored while being iterated using ``QueryBuilder.iterall``.
This is a regression test for https://github.com/aiidateam/aiida-core/issues/5802 .
"""
count = 10
pks = []
pks_clone = []

for index in range(count):
node = orm.Int(index).store()
pks.append(node.pk)

# Ensure that batch size is smaller than the total rows yielded
for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=2):
clone = copy.deepcopy(node)
clone.store()
pks_clone.append((clone.value, clone.pk))
group = orm.Group(label=str(node.uuid)).store()
group.add_nodes([node])

# Need to sort the cloned pks based on the value, because the order of ``iterall`` is not guaranteed
for pk, pk_clone in zip(pks, [e[1] for e in sorted(pks_clone)]):
assert orm.load_node(pk) == orm.load_node(pk_clone)


class TestManager:

Expand Down
2 changes: 1 addition & 1 deletion tests/storage/psql_dos/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_multiple_node_creation(self):
from aiida.storage.psql_dos.models.node import DbNode

# Get the automatic user
dbuser = self.backend.users.create('user@aiida.net').store().bare_model
dbuser = self.backend.users.create(get_new_uuid()).store().bare_model
# Create a new node but don't add it to the session
node_uuid = get_new_uuid()
DbNode(user=dbuser, uuid=node_uuid, node_type=None)
Expand Down
6 changes: 4 additions & 2 deletions tests/storage/psql_dos/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,15 @@ def test_node_access_with_sessions(self):
custom_session = session()

try:
user = self.backend.users.create(email=uuid.uuid4().hex).store()
node = self.backend.nodes.create(node_type='', user=user).store()
with self.backend.transaction():
user = self.backend.users.create(email=uuid.uuid4().hex).store()
node = self.backend.nodes.create(node_type='', user=user).store()
master_session = node.model.session # pylint: disable=protected-access
assert master_session is not custom_session

# Manually load the DbNode in a different session
dbnode_reloaded = custom_session.get(sa.models.node.DbNode, node.id)
assert dbnode_reloaded is not None

# Now, go through one by one changing the possible attributes (of the model)
# and check that they're updated when the user reads them from the aiida node
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,14 @@ def test_uuid_uniquess(self):
"""
A uniqueness constraint on the UUID column of the Node model should prevent multiple nodes with identical UUID
"""
from sqlalchemy.exc import IntegrityError as SqlaIntegrityError
from aiida.common.exceptions import IntegrityError

a = orm.Data()
b = orm.Data()
b.backend_entity.bare_model.uuid = a.uuid
a.store()

with pytest.raises(SqlaIntegrityError):
with pytest.raises(IntegrityError):
b.store()

def test_attribute_mutability(self):
Expand Down

0 comments on commit 98e902b

Please sign in to comment.