Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: Improve pending migration(s) check #2013

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions agenta-backend/agenta_backend/migrations/postgres/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ def is_initial_setup(engine) -> bool:
return not all_tables_exist


async def get_applied_migrations(engine: AsyncEngine):
async def get_current_migration_head_from_db(engine: AsyncEngine):
"""
Checks the alembic_version table to get all the migrations that has been applied.
Checks the alembic_version table to get the current migration head that has been applied.

Args:
engine (Engine): The engine that connects to an sqlalchemy pool

Returns:
a list of strings
the current migration head (where 'head' is the revision stored in the migration script)
"""

async with engine.connect() as connection:
Expand All @@ -75,32 +75,37 @@ async def get_applied_migrations(engine: AsyncEngine):
# to make Alembic start tracking the migration changes.
# --------------------------------------------------------------------------------------
# This effect (the exception raising) happens for both users (first-time and returning)
return ["alembic_version"]
return "alembic_version"

applied_migrations = [row[0] for row in result.fetchall()]
return applied_migrations
migration_heads = [row[0] for row in result.fetchall()]
assert (
len(migration_heads) == 1
), "There can only be one migration head stored in the database."
return migration_heads[0]


async def get_pending_migrations():
async def get_pending_migration_head():
"""
Gets the migrations that have not been applied.
Gets the migration head that have not been applied.

Returns:
the number of pending migrations
the pending migration head
"""

engine = create_async_engine(url=os.environ["POSTGRES_URI"])
try:
applied_migrations = await get_applied_migrations(engine=engine)
migration_files = [script.revision for script in script.walk_revisions()]
pending_migrations = [m for m in migration_files if m not in applied_migrations]

if "alembic_version" in applied_migrations:
pending_migrations.append("alembic_version")
current_migration_script_head = script.get_current_head()
migration_head_from_db = await get_current_migration_head_from_db(engine=engine)

pending_migration_head = []
if current_migration_script_head != migration_head_from_db:
pending_migration_head.append(current_migration_script_head)
if "alembic_version" == migration_head_from_db:
pending_migration_head.append("alembic_version")
finally:
await engine.dispose()

return pending_migrations
return pending_migration_head


def run_alembic_migration():
Expand All @@ -109,9 +114,9 @@ def run_alembic_migration():
"""

try:
pending_migrations = asyncio.run(get_pending_migrations())
pending_migration_head = asyncio.run(get_pending_migration_head())
APPLY_AUTO_MIGRATIONS = os.environ.get("AGENTA_AUTO_MIGRATIONS")
FIRST_TIME_USER = True if "alembic_version" in pending_migrations else False
FIRST_TIME_USER = True if "alembic_version" in pending_migration_head else False

if FIRST_TIME_USER or APPLY_AUTO_MIGRATIONS == "true":
command.upgrade(alembic_cfg, "head")
Expand All @@ -133,7 +138,7 @@ def run_alembic_migration():
except Exception as e:
click.echo(
click.style(
f"\nAn ERROR occured while applying migration: {traceback.format_exc()}\nThe container will now exit.",
f"\nAn ERROR occurred while applying migration: {traceback.format_exc()}\nThe container will now exit.",
fg="red",
),
color=True,
Expand All @@ -146,11 +151,11 @@ async def check_for_new_migrations():
Checks for new migrations and notify the user.
"""

pending_migrations = await get_pending_migrations()
if len(pending_migrations) >= 1:
pending_migration_head = await get_pending_migration_head()
if len(pending_migration_head) >= 1 and isinstance(pending_migration_head[0], str):
click.echo(
click.style(
f"\nWe have detected that there are pending database migrations {pending_migrations} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n",
f"\nWe have detected that there are pending database migrations {pending_migration_head} that need to be applied to keep the application up to date. To ensure the application functions correctly with the latest updates, please follow the guide here => https://docs.agenta.ai/self-host/migration/applying-schema-migration\n",
fg="yellow",
),
color=True,
Expand Down
Loading