Skip to content

Commit

Permalink
Added unique=True on Version (#157)
Browse files Browse the repository at this point in the history
* added unique=True on Version; re synthesized-io/sdk#582

* updated get_version_id

* added existing_nullable=False to a migration

* added tests for coverage
  • Loading branch information
marqueewinq authored Dec 18, 2023
1 parent 68dc221 commit f2e861e
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Set version.name.unique = True
Revision ID: a2198ae60b44
Revises: d2198fd60b0e
Create Date: 2023-12-13 13:25:17.878689
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "a2198ae60b44"
down_revision = "d2198fd60b0e"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"version",
"name",
existing_type=sa.VARCHAR(length=50),
unique=True,
existing_nullable=False,
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column(
"version",
"name",
existing_type=sa.VARCHAR(length=50),
unique=False,
existing_nullable=False,
)
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion src/insight/database/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Version(Base):
__tablename__ = "version"

id = mapped_column(INTEGER, primary_key=True)
name = mapped_column(VARCHAR(50), nullable=False, default="unversioned")
name = mapped_column(VARCHAR(50), nullable=False, default="unversioned", unique=True)
created_at = mapped_column(TIMESTAMP, default=func.now())


Expand Down
37 changes: 28 additions & 9 deletions src/insight/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql import select

Expand Down Expand Up @@ -76,16 +77,34 @@ def get_version_id(version: str, session: Session) -> int:
Args:
version (str): The name of the version.
session (Session): The database session.
Returns:
int: The id of the version.
"""
db_version = get_object_from_db_by_name(version, session, model.Version)
if db_version is None:
with session:
db_version = model.Version(name=version)
session.add(db_version)
session.commit()
if not db_version.id:
raise ConnectionError(_database_fail_note)
return int(db_version.id)
try:
db_version = get_object_from_db_by_name(version, session, model.Version)
if db_version is None:
with session.begin_nested():
db_version = model.Version(name=version)
session.add(db_version)
session.commit()

if not db_version.id:
raise ConnectionError(_database_fail_note)

return int(db_version.id)

except IntegrityError as e:
session.rollback()
# Handle the integrity error by looking up the existing version
db_version = get_object_from_db_by_name(version, session, model.Version)
if db_version and db_version.id:
return int(db_version.id)
raise e

except SQLAlchemyError as e:
session.rollback()
raise e


def get_object_from_db_by_name(
Expand Down
65 changes: 65 additions & 0 deletions tests/test_database/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from unittest.mock import MagicMock

import pytest
from sqlalchemy.exc import IntegrityError, SQLAlchemyError

from insight.database.utils import get_version_id


@pytest.fixture
def mock_session():
session = MagicMock()
session.begin_nested.return_value.__enter__.return_value = session

def add_side_effect(model_instance):
model_instance.id = 123

session.add.side_effect = add_side_effect
session.commit = MagicMock()

session.rollback = MagicMock()

mock_scalar_one_or_none = MagicMock()
mock_scalar_one_or_none.id = 789
executed = MagicMock()
executed.scalar_one_or_none.return_value = mock_scalar_one_or_none
session.execute.return_value = executed

assert session.execute().scalar_one_or_none().id == 789
return session


def test_get_version_id_existing_version(mock_session):
mock_session.execute.return_value.scalar_one_or_none.return_value = MagicMock(id=123)
version_id = get_version_id("existing_version", mock_session)
assert version_id == 123


def test_get_version_id_new_version(mock_session):
mock_session.execute.return_value.scalar_one_or_none.return_value = None
mock_session.begin_nested.return_value.__enter__.return_value.add.return_value = MagicMock(
id=123
)
version_id = get_version_id("new_version", mock_session)
assert version_id == 123


def test_get_version_id_integrity_error(mock_session):
# First call to execute raises IntegrityError
# Second call to execute returns a MagicMock with the correct id
second_execute = MagicMock()
second_execute.scalar_one_or_none.return_value = MagicMock(id=789)
mock_session.execute.side_effect = [
IntegrityError("Mocked Integrity Error", "params", "orig"),
second_execute,
]

version_id = get_version_id("version_with_error", mock_session)

assert version_id == 789


def test_get_version_id_sqlalchemy_error(mock_session):
mock_session.execute.side_effect = SQLAlchemyError("Mocked SQLAlchemy Error")
with pytest.raises(SQLAlchemyError):
get_version_id("version_with_sqlalchemy_error", mock_session)

0 comments on commit f2e861e

Please sign in to comment.