diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 25746889..cc17ee75 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -21,7 +21,12 @@ from typing import Union from sqlalchemy import cast +from sqlalchemy import Column +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import schema +from sqlalchemy import String +from sqlalchemy import Table from sqlalchemy import text from . import _autogen @@ -43,11 +48,9 @@ from sqlalchemy.sql import Executable from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import quoted_name - from sqlalchemy.sql.schema import Column from sqlalchemy.sql.schema import Constraint from sqlalchemy.sql.schema import ForeignKeyConstraint from sqlalchemy.sql.schema import Index - from sqlalchemy.sql.schema import Table from sqlalchemy.sql.schema import UniqueConstraint from sqlalchemy.sql.selectable import TableClause from sqlalchemy.sql.type_api import TypeEngine @@ -136,6 +139,32 @@ def static_output(self, text: str) -> None: self.output_buffer.write(text + "\n\n") self.output_buffer.flush() + def version_table_impl( + self, *, version_table, version_table_schema, version_table_pk, **kw + ) -> Table: + """create the Table object for the version_table. + + Provided as part of impl so that third party dialects can override + this. + + .. versionadded:: 1.13.4 + + """ + vt = Table( + version_table, + MetaData(), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint( + "version_num", name=f"{version_table}_pkc" + ) + ) + + return vt + def requires_recreate_in_batch( self, batch_op: BatchOperationsImpl ) -> bool: diff --git a/alembic/runtime/migration.py b/alembic/runtime/migration.py index 6cfe5e23..28f01c3b 100644 --- a/alembic/runtime/migration.py +++ b/alembic/runtime/migration.py @@ -24,10 +24,6 @@ from sqlalchemy import Column from sqlalchemy import literal_column -from sqlalchemy import MetaData -from sqlalchemy import PrimaryKeyConstraint -from sqlalchemy import String -from sqlalchemy import Table from sqlalchemy.engine import Engine from sqlalchemy.engine import url as sqla_url from sqlalchemy.engine.strategies import MockEngineStrategy @@ -36,6 +32,7 @@ from .. import util from ..util import sqla_compat from ..util.compat import EncodedIO +from ..util.sqla_compat import _select if TYPE_CHECKING: from sqlalchemy.engine import Dialect @@ -190,18 +187,6 @@ def __init__( self.version_table_schema = version_table_schema = opts.get( "version_table_schema", None ) - self._version = Table( - version_table, - MetaData(), - Column("version_num", String(32), nullable=False), - schema=version_table_schema, - ) - if opts.get("version_table_pk", True): - self._version.append_constraint( - PrimaryKeyConstraint( - "version_num", name="%s_pkc" % version_table - ) - ) self._start_from_rev: Optional[str] = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( @@ -212,6 +197,13 @@ def __init__( self.output_buffer, opts, ) + + self._version = self.impl.version_table_impl( + version_table=version_table, + version_table_schema=version_table_schema, + version_table_pk=opts.get("version_table_pk", True), + ) + log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") @@ -540,7 +532,10 @@ def get_current_heads(self) -> Tuple[str, ...]: return () assert self.connection is not None return tuple( - row[0] for row in self.connection.execute(self._version.select()) + row[0] + for row in self.connection.execute( + _select(self._version.c.version_num) + ) ) def _ensure_version_table(self, purge: bool = False) -> None: diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 5ad3c21d..4fa36014 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -1,10 +1,13 @@ from sqlalchemy import Column from sqlalchemy import inspect +from sqlalchemy import Integer from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint from sqlalchemy import String from sqlalchemy import Table from alembic import migration +from alembic.ddl import impl from alembic.testing import assert_raises from alembic.testing import assert_raises_message from alembic.testing import config @@ -373,3 +376,46 @@ def test_delete_multi_match_no_sane_rowcount(self): self.connection.dialect, "supports_sane_rowcount", False ): self.updater.update_to_step(_down("a", None, True)) + + +class CustomVersionTableTest(TestMigrationContext): + + class MyDialectImpl(impl.DefaultImpl): + + def version_table_impl( + self, + *, + version_table, + version_table_schema, + version_table_pk, + **kw, + ): + vt = Table( + version_table, + MetaData(), + Column("id", Integer, autoincrement=True), + Column("version_num", String(32), nullable=False), + schema=version_table_schema, + ) + if version_table_pk: + vt.append_constraint( + PrimaryKeyConstraint("id", name=f"{version_table}_pkc") + ) + return vt + + def setUp(self): + # nasty hack to get the sqlite dialect + # to use our custom dialect implementation + impl._impls["sqlite_bak"] = impl._impls["sqlite"] + impl._impls["sqlite"] = self.MyDialectImpl + super().setUp() + + def tearDown(self): + super().tearDown() + impl._impls["sqlite"] = impl._impls["sqlite_bak"] + + def test_custom_version_table(self): + context = migration.MigrationContext.configure( + dialect_name="sqlite", + ) + eq_(len(context._version.columns), 2)