From f2e861e9b65da83ffd4a1ae423a13b8a29a67fa5 Mon Sep 17 00:00:00 2001 From: Mark Date: Mon, 18 Dec 2023 15:18:36 +0300 Subject: [PATCH] Added unique=True on Version (#157) * added unique=True on Version; re https://github.com/synthesized-io/sdk/issues/582 * updated get_version_id * added existing_nullable=False to a migration * added tests for coverage --- ...b44_added_version_name_unique_constaint.py | 39 +++++++++++ src/insight/database/schema.py | 2 +- src/insight/database/utils.py | 37 ++++++++--- tests/test_database/test_utils.py | 65 +++++++++++++++++++ 4 files changed, 133 insertions(+), 10 deletions(-) create mode 100644 src/insight/alembic/versions/a2198ae60b44_added_version_name_unique_constaint.py create mode 100644 tests/test_database/test_utils.py diff --git a/src/insight/alembic/versions/a2198ae60b44_added_version_name_unique_constaint.py b/src/insight/alembic/versions/a2198ae60b44_added_version_name_unique_constaint.py new file mode 100644 index 00000000..d8155847 --- /dev/null +++ b/src/insight/alembic/versions/a2198ae60b44_added_version_name_unique_constaint.py @@ -0,0 +1,39 @@ +"""Set version.name.unique = True + +Revision ID: a2198ae60b44 +Revises: d2198fd60b0e +Create Date: 2023-12-13 13:25:17.878689 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a2198ae60b44" +down_revision = "d2198fd60b0e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "version", + "name", + existing_type=sa.VARCHAR(length=50), + unique=True, + existing_nullable=False, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "version", + "name", + existing_type=sa.VARCHAR(length=50), + unique=False, + existing_nullable=False, + ) + # ### end Alembic commands ### diff --git a/src/insight/database/schema.py b/src/insight/database/schema.py index cbec19d4..f61e1c03 100644 --- a/src/insight/database/schema.py +++ b/src/insight/database/schema.py @@ -29,7 +29,7 @@ class Version(Base): __tablename__ = "version" id = mapped_column(INTEGER, primary_key=True) - name = mapped_column(VARCHAR(50), nullable=False, default="unversioned") + name = mapped_column(VARCHAR(50), nullable=False, default="unversioned", unique=True) created_at = mapped_column(TIMESTAMP, default=func.now()) diff --git a/src/insight/database/utils.py b/src/insight/database/utils.py index 1ff70dbf..eb42868b 100644 --- a/src/insight/database/utils.py +++ b/src/insight/database/utils.py @@ -5,6 +5,7 @@ import pandas as pd from sqlalchemy import create_engine +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql import select @@ -76,16 +77,34 @@ def get_version_id(version: str, session: Session) -> int: Args: version (str): The name of the version. session (Session): The database session. + + Returns: + int: The id of the version. """ - db_version = get_object_from_db_by_name(version, session, model.Version) - if db_version is None: - with session: - db_version = model.Version(name=version) - session.add(db_version) - session.commit() - if not db_version.id: - raise ConnectionError(_database_fail_note) - return int(db_version.id) + try: + db_version = get_object_from_db_by_name(version, session, model.Version) + if db_version is None: + with session.begin_nested(): + db_version = model.Version(name=version) + session.add(db_version) + session.commit() + + if not db_version.id: + raise ConnectionError(_database_fail_note) + + return int(db_version.id) + + except IntegrityError as e: + session.rollback() + # Handle the integrity error by looking up the existing version + db_version = get_object_from_db_by_name(version, session, model.Version) + if db_version and db_version.id: + return int(db_version.id) + raise e + + except SQLAlchemyError as e: + session.rollback() + raise e def get_object_from_db_by_name( diff --git a/tests/test_database/test_utils.py b/tests/test_database/test_utils.py new file mode 100644 index 00000000..d1543298 --- /dev/null +++ b/tests/test_database/test_utils.py @@ -0,0 +1,65 @@ +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.exc import IntegrityError, SQLAlchemyError + +from insight.database.utils import get_version_id + + +@pytest.fixture +def mock_session(): + session = MagicMock() + session.begin_nested.return_value.__enter__.return_value = session + + def add_side_effect(model_instance): + model_instance.id = 123 + + session.add.side_effect = add_side_effect + session.commit = MagicMock() + + session.rollback = MagicMock() + + mock_scalar_one_or_none = MagicMock() + mock_scalar_one_or_none.id = 789 + executed = MagicMock() + executed.scalar_one_or_none.return_value = mock_scalar_one_or_none + session.execute.return_value = executed + + assert session.execute().scalar_one_or_none().id == 789 + return session + + +def test_get_version_id_existing_version(mock_session): + mock_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(id=123) + version_id = get_version_id("existing_version", mock_session) + assert version_id == 123 + + +def test_get_version_id_new_version(mock_session): + mock_session.execute.return_value.scalar_one_or_none.return_value = None + mock_session.begin_nested.return_value.__enter__.return_value.add.return_value = MagicMock( + id=123 + ) + version_id = get_version_id("new_version", mock_session) + assert version_id == 123 + + +def test_get_version_id_integrity_error(mock_session): + # First call to execute raises IntegrityError + # Second call to execute returns a MagicMock with the correct id + second_execute = MagicMock() + second_execute.scalar_one_or_none.return_value = MagicMock(id=789) + mock_session.execute.side_effect = [ + IntegrityError("Mocked Integrity Error", "params", "orig"), + second_execute, + ] + + version_id = get_version_id("version_with_error", mock_session) + + assert version_id == 789 + + +def test_get_version_id_sqlalchemy_error(mock_session): + mock_session.execute.side_effect = SQLAlchemyError("Mocked SQLAlchemy Error") + with pytest.raises(SQLAlchemyError): + get_version_id("version_with_sqlalchemy_error", mock_session)