Skip to content

Commit

Permalink
Merge "Fix ExcludeConstraint with func." into main
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzeek authored and Gerrit Code Review committed May 3, 2023
2 parents 3336e37 + ac0785d commit 9f70100
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 37 deletions.
51 changes: 22 additions & 29 deletions alembic/ddl/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.elements import UnaryExpression
from sqlalchemy.sql.functions import FunctionElement
from sqlalchemy.types import NULLTYPE

from .base import alter_column
Expand Down Expand Up @@ -662,22 +663,15 @@ def _exclude_constraint(
("name", render._render_gen_name(autogen_context, constraint.name))
)

if alter:
def do_expr_where_opts():
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
"(%s, %r)"
% (
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(
[
"(%s, %r)"
% (
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
)
if constraint.where is not None:
args.append(
"where=%s"
Expand All @@ -686,32 +680,29 @@ def _exclude_constraint(
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return args

if alter:
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(do_expr_where_opts())
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
"prefix": render._alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
else:
args = [
"(%s, %r)"
% (_render_potential_column(sqltext, autogen_context), opstring)
for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
args = do_expr_where_opts()
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}


def _render_potential_column(
value: Union[ColumnClause, Column, TextClause],
value: Union[ColumnClause, Column, TextClause, FunctionElement],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
Expand All @@ -727,5 +718,7 @@ def _render_potential_column(
}
else:
return render._render_potential_expr(
value, autogen_context, wrap_in_text=isinstance(value, TextClause)
value,
autogen_context,
wrap_in_text=isinstance(value, (TextClause, FunctionElement)),
)
7 changes: 7 additions & 0 deletions docs/build/unreleased/1230.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.. change::
:tags: bug, postgresql
:tickets: 1230

Fix autogenerate issue with PostgreSQL :class:`.ExcludeConstraint`
that included sqlalchemy functions. The function text was previously
rendered as a plain string without surrounding with ``text()``.
41 changes: 33 additions & 8 deletions tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,8 +1053,6 @@ def repr_type(typestring, object_, autogen_context):
)

def test_add_exclude_constraint(self):
from sqlalchemy.dialects.postgresql import ExcludeConstraint

autogen_context = self.autogen_context

m = MetaData()
Expand All @@ -1074,8 +1072,6 @@ def test_add_exclude_constraint(self):
)

def test_add_exclude_constraint_case_sensitive(self):
from sqlalchemy.dialects.postgresql import ExcludeConstraint

autogen_context = self.autogen_context

m = MetaData()
Expand All @@ -1100,8 +1096,6 @@ def test_add_exclude_constraint_case_sensitive(self):
)

def test_inline_exclude_constraint(self):
from sqlalchemy.dialects.postgresql import ExcludeConstraint

autogen_context = self.autogen_context

m = MetaData()
Expand Down Expand Up @@ -1130,8 +1124,6 @@ def test_inline_exclude_constraint(self):
)

def test_inline_exclude_constraint_case_sensitive(self):
from sqlalchemy.dialects.postgresql import ExcludeConstraint

autogen_context = self.autogen_context

m = MetaData()
Expand Down Expand Up @@ -1183,6 +1175,39 @@ def test_inline_exclude_constraint_literal_column(self):
"name='TExclID'))",
)

@config.requirements.sqlalchemy_2
def test_inline_exclude_constraint_fn(self):
"""test for #1230"""

autogen_context = self.autogen_context

effective_time = Column("effective_time", DateTime(timezone=True))
expiry_time = Column("expiry_time", DateTime(timezone=True))

m = MetaData()
t = Table(
"TTable",
m,
effective_time,
expiry_time,
ExcludeConstraint(
(func.tstzrange(effective_time, expiry_time), "&&"),
using="gist",
),
)

op_obj = ops.CreateTableOp.from_table(t)

eq_ignore_whitespace(
autogenerate.render_op_text(autogen_context, op_obj),
"op.create_table('TTable',sa.Column('effective_time', "
"sa.DateTime(timezone=True), nullable=True),"
"sa.Column('expiry_time', sa.DateTime(timezone=True), "
"nullable=True),postgresql.ExcludeConstraint("
"(sa.text('tstzrange(effective_time, expiry_time)'), "
"'&&'), using='gist'))",
)

@config.requirements.sqlalchemy_2
def test_inline_exclude_constraint_text(self):
"""test for #1184.
Expand Down

0 comments on commit 9f70100

Please sign in to comment.