From 8abba499cb0d88810f4c08e766c9042c3f52f34a Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Thu, 12 Dec 2024 18:02:42 +0000 Subject: [PATCH] [typing] prefect.server.utilities.database This is a complete refactor of this module. - Move functions into the `sqlalchemy.func` namespace so they don't need to be imported everywhere - Re-use SQLAlchemy's PostgreSQL JSONB operators by providing SQLite equivalents - Provide a new function that calculates the difference between timestamps as seconds. This removes the need for many separate PostgreSQL vs SQLite queries. --- src/prefect/server/api/run_history.py | 6 +- src/prefect/server/api/ui/task_runs.py | 43 +- src/prefect/server/database/configurations.py | 18 +- src/prefect/server/database/interface.py | 3 - src/prefect/server/database/orm_models.py | 41 +- .../server/database/query_components.py | 22 +- src/prefect/server/events/counting.py | 13 +- src/prefect/server/events/filters.py | 13 +- .../server/models/concurrency_limits_v2.py | 42 +- src/prefect/server/models/deployments.py | 9 +- src/prefect/server/schemas/filters.py | 31 +- src/prefect/server/utilities/database.py | 784 +++++++++--------- tests/server/utilities/test_database.py | 379 +++++---- 13 files changed, 705 insertions(+), 699 deletions(-) diff --git a/src/prefect/server/api/run_history.py b/src/prefect/server/api/run_history.py index f70976f4a1e9..932a2076e4ea 100644 --- a/src/prefect/server/api/run_history.py +++ b/src/prefect/server/api/run_history.py @@ -102,12 +102,14 @@ async def run_history( # estimated run times only includes positive run times (to avoid any unexpected corner cases) "sum_estimated_run_time", sa.func.sum( - db.greatest(0, sa.extract("epoch", runs.c.estimated_run_time)) + sa.func.greatest( + 0, sa.extract("epoch", runs.c.estimated_run_time) + ) ), # estimated lateness is the sum of any positive start time deltas "sum_estimated_lateness", sa.func.sum( - db.greatest( + sa.func.greatest( 0, sa.extract("epoch", 
runs.c.estimated_start_time_delta) ) ), diff --git a/src/prefect/server/api/ui/task_runs.py b/src/prefect/server/api/ui/task_runs.py index b8f4bb778240..f39565499ba8 100644 --- a/src/prefect/server/api/ui/task_runs.py +++ b/src/prefect/server/api/ui/task_runs.py @@ -1,5 +1,4 @@ -import sys -from datetime import datetime, timezone +from datetime import datetime from typing import List, Optional, cast import pendulum @@ -37,37 +36,6 @@ def ser_model(self) -> dict: } -def _postgres_bucket_expression( - db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime -): - # asyncpg under Python 3.7 doesn't support timezone-aware datetimes for the EXTRACT - # function, so we will send it as a naive datetime in UTC - if sys.version_info < (3, 8): - start_datetime = start_datetime.astimezone(timezone.utc).replace(tzinfo=None) - - return sa.func.floor( - ( - sa.func.extract("epoch", db.TaskRun.start_time) - - sa.func.extract("epoch", start_datetime) - ) - / delta.total_seconds() - ).label("bucket") - - -def _sqlite_bucket_expression( - db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime -): - return sa.func.floor( - ( - ( - sa.func.strftime("%s", db.TaskRun.start_time) - - sa.func.strftime("%s", start_datetime) - ) - / delta.total_seconds() - ) - ).label("bucket") - - @router.post("/dashboard/counts") async def read_dashboard_task_run_counts( task_runs: schemas.filters.TaskRunFilter, @@ -121,11 +89,10 @@ async def read_dashboard_task_run_counts( start_time.microsecond, start_time.timezone, ) - bucket_expression = ( - _sqlite_bucket_expression(db, delta, start_datetime) - if db.dialect.name == "sqlite" - else _postgres_bucket_expression(db, delta, start_datetime) - ) + bucket_expression = sa.func.floor( + sa.func.date_diff_seconds(db.TaskRun.start_time, start_datetime) + / delta.total_seconds() + ).label("bucket") raw_counts = ( ( diff --git a/src/prefect/server/database/configurations.py b/src/prefect/server/database/configurations.py 
index 9721f16d7dfc..9eadbaab65bd 100644 --- a/src/prefect/server/database/configurations.py +++ b/src/prefect/server/database/configurations.py @@ -8,16 +8,14 @@ from typing import Dict, Hashable, Optional, Tuple import sqlalchemy as sa - -try: - from sqlalchemy import AdaptedConnection - from sqlalchemy.pool import ConnectionPoolEntry -except ImportError: - # SQLAlchemy 1.4 equivalents - from sqlalchemy.pool import _ConnectionFairy as AdaptedConnection - from sqlalchemy.pool.base import _ConnectionRecord as ConnectionPoolEntry - -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine +from sqlalchemy import AdaptedConnection +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + AsyncSessionTransaction, + create_async_engine, +) +from sqlalchemy.pool import ConnectionPoolEntry from typing_extensions import Literal from prefect.settings import ( diff --git a/src/prefect/server/database/interface.py b/src/prefect/server/database/interface.py index c1ab752fe30b..d4a427abdb50 100644 --- a/src/prefect/server/database/interface.py +++ b/src/prefect/server/database/interface.py @@ -359,9 +359,6 @@ def insert(self, model): """INSERTs a model into the database""" return self.queries.insert(model) - def greatest(self, *values): - return self.queries.greatest(*values) - def make_timestamp_intervals( self, start_time: datetime.datetime, diff --git a/src/prefect/server/database/orm_models.py b/src/prefect/server/database/orm_models.py index 49d89e6b6559..5728cac94fe0 100644 --- a/src/prefect/server/database/orm_models.py +++ b/src/prefect/server/database/orm_models.py @@ -1,17 +1,15 @@ import datetime import uuid from abc import ABC, abstractmethod +from collections.abc import Hashable, Iterable from pathlib import Path from typing import ( TYPE_CHECKING, Any, ClassVar, Dict, - Hashable, - Iterable, Optional, Union, - cast, ) import pendulum @@ -46,15 +44,12 @@ WorkQueueStatus, ) from prefect.server.utilities.database import ( + 
CAMEL_TO_SNAKE, JSON, UUID, GenerateUUID, Pydantic, Timestamp, - camel_to_snake, - date_diff, - interval_add, - now, ) from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet from prefect.utilities.names import generate_slug @@ -117,7 +112,7 @@ def __tablename__(cls) -> str: into a snake-case table name. Override by providing an explicit `__tablename__` class property. """ - return camel_to_snake.sub("_", cls.__name__).lower() + return CAMEL_TO_SNAKE.sub("_", cls.__name__).lower() id: Mapped[uuid.UUID] = mapped_column( primary_key=True, @@ -126,7 +121,7 @@ def __tablename__(cls) -> str: ) created: Mapped[pendulum.DateTime] = mapped_column( - server_default=now(), default=lambda: pendulum.now("UTC") + server_default=sa.func.now(), default=lambda: pendulum.now("UTC") ) # onupdate is only called when statements are actually issued @@ -134,9 +129,9 @@ def __tablename__(cls) -> str: # will not be updated updated: Mapped[pendulum.DateTime] = mapped_column( index=True, - server_default=now(), + server_default=sa.func.now(), default=lambda: pendulum.now("UTC"), - onupdate=now(), + onupdate=sa.func.now(), server_onupdate=FetchedValue(), ) @@ -175,7 +170,7 @@ class FlowRunState(Base): sa.Enum(schemas.states.StateType, name="state_type"), index=True ) timestamp: Mapped[pendulum.DateTime] = mapped_column( - server_default=now(), default=lambda: pendulum.now("UTC") + server_default=sa.func.now(), default=lambda: pendulum.now("UTC") ) name: Mapped[str] = mapped_column(index=True) message: Mapped[Optional[str]] @@ -240,7 +235,7 @@ class TaskRunState(Base): sa.Enum(schemas.states.StateType, name="state_type"), index=True ) timestamp: Mapped[pendulum.DateTime] = mapped_column( - server_default=now(), default=lambda: pendulum.now("UTC") + server_default=sa.func.now(), default=lambda: pendulum.now("UTC") ) name: Mapped[str] = mapped_column(index=True) message: Mapped[Optional[str]] @@ -419,9 +414,9 @@ def _estimated_run_time_expression(cls) -> 
sa.Label[datetime.timedelta]: sa.case( ( cls.state_type == schemas.states.StateType.RUNNING, - interval_add( + sa.func.interval_add( cls.total_run_time, - date_diff(now(), cls.state_timestamp), + sa.func.date_diff(sa.func.now(), cls.state_timestamp), ), ), else_=cls.total_run_time, @@ -464,15 +459,15 @@ def _estimated_start_time_delta_expression( return sa.case( ( cls.start_time > cls.expected_start_time, - date_diff(cls.start_time, cls.expected_start_time), + sa.func.date_diff(cls.start_time, cls.expected_start_time), ), ( sa.and_( cls.start_time.is_(None), cls.state_type.not_in(schemas.states.TERMINAL_STATES), - cls.expected_start_time < now(), + cls.expected_start_time < sa.func.now(), ), - date_diff(now(), cls.expected_start_time), + sa.func.date_diff(sa.func.now(), cls.expected_start_time), ), else_=datetime.timedelta(0), ) @@ -1165,7 +1160,7 @@ class Worker(Base): name: Mapped[str] last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column( - server_default=now(), default=lambda: pendulum.now("UTC") + server_default=sa.func.now(), default=lambda: pendulum.now("UTC") ) heartbeat_interval_seconds: Mapped[Optional[int]] @@ -1195,7 +1190,7 @@ class Agent(Base): ) last_activity_time: Mapped[pendulum.DateTime] = mapped_column( - server_default=now(), default=lambda: pendulum.now("UTC") + server_default=sa.func.now(), default=lambda: pendulum.now("UTC") ) __table_args__: Any = (sa.UniqueConstraint("name"),) @@ -1277,11 +1272,11 @@ class Automation(Base): @classmethod def sort_expression(cls, value: AutomationSort) -> sa.ColumnExpressionArgument[Any]: """Return an expression used to sort Automations""" - sort_mapping = { + sort_mapping: dict[AutomationSort, sa.ColumnExpressionArgument[Any]] = { AutomationSort.CREATED_DESC: cls.created.desc(), AutomationSort.UPDATED_DESC: cls.updated.desc(), - AutomationSort.NAME_ASC: cast(sa.Column, cls.name).asc(), - AutomationSort.NAME_DESC: cast(sa.Column, cls.name).desc(), + AutomationSort.NAME_ASC: cls.name.asc(), + 
AutomationSort.NAME_DESC: cls.name.desc(), } return sort_mapping[value] diff --git a/src/prefect/server/database/query_components.py b/src/prefect/server/database/query_components.py index 6aaad5c76251..0b1ad0756b65 100644 --- a/src/prefect/server/database/query_components.py +++ b/src/prefect/server/database/query_components.py @@ -61,14 +61,6 @@ def _unique_key(self) -> Tuple[Hashable, ...]: def insert(self, obj) -> Union[postgresql.Insert, sqlite.Insert]: """dialect-specific insert statement""" - @abstractmethod - def greatest(self, *values): - """dialect-specific SqlAlchemy binding""" - - @abstractmethod - def least(self, *values): - """dialect-specific SqlAlchemy binding""" - # --- dialect-specific JSON handling @abstractproperty @@ -179,7 +171,7 @@ def get_scheduled_flow_runs_from_work_queues( concurrency_queues = ( sa.select( orm_models.WorkQueue.id, - self.greatest( + sa.func.greatest( 0, orm_models.WorkQueue.concurrency_limit - sa.func.count(orm_models.FlowRun.id), @@ -628,12 +620,6 @@ class AsyncPostgresQueryComponents(BaseQueryComponents): def insert(self, obj) -> postgresql.Insert: return postgresql.insert(obj) - def greatest(self, *values): - return sa.func.greatest(*values) - - def least(self, *values): - return sa.func.least(*values) - # --- Postgres-specific JSON handling @property @@ -984,12 +970,6 @@ class AioSqliteQueryComponents(BaseQueryComponents): def insert(self, obj) -> sqlite.Insert: return sqlite.insert(obj) - def greatest(self, *values): - return sa.func.max(*values) - - def least(self, *values): - return sa.func.min(*values) - # --- Sqlite-specific JSON handling @property diff --git a/src/prefect/server/events/counting.py b/src/prefect/server/events/counting.py index 72d6051c0d19..ec14ad7f70a0 100644 --- a/src/prefect/server/events/counting.py +++ b/src/prefect/server/events/counting.py @@ -8,7 +8,6 @@ from prefect.server.database.dependencies import provide_database_interface from prefect.server.database.interface import 
PrefectDBInterface -from prefect.server.utilities.database import json_extract from prefect.types import DateTime from prefect.utilities.collections import AutoEnum @@ -290,16 +289,8 @@ def _database_label_expression( return db.Event.event elif self == self.resource: return sa.func.coalesce( - json_extract( - db.Event.resource, - "prefect.resource.name", - wrap_quotes=db.dialect.name == "sqlite", - ), - json_extract( - db.Event.resource, - "prefect.name", - wrap_quotes=db.dialect.name == "sqlite", - ), + db.Event.resource["prefect.resource.name"].astext, + db.Event.resource["prefect.name"].astext, db.Event.resource_id, ) else: diff --git a/src/prefect/server/events/filters.py b/src/prefect/server/events/filters.py index d4f2453b09e4..8b87829f9614 100644 --- a/src/prefect/server/events/filters.py +++ b/src/prefect/server/events/filters.py @@ -15,7 +15,6 @@ PrefectFilterBaseModel, PrefectOperatorFilterBaseModel, ) -from prefect.server.utilities.database import json_extract from prefect.types import DateTime from prefect.utilities.collections import AutoEnum @@ -309,9 +308,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = json_extract( - orm_models.EventResource.resource, label - ) + label_column = orm_models.EventResource.resource[label].astext # With negative labels, the resource _must_ have the label if label_ops.negative.simple or label_ops.negative.prefixes: @@ -404,9 +401,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = json_extract( - orm_models.EventResource.resource, label - ) + label_column = orm_models.EventResource.resource[label].astext if label_ops.negative.simple or label_ops.negative.prefixes: label_filters.append(label_column.is_not(None)) @@ -518,9 +513,7 @@ def 
build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]: for _, (label, values) in enumerate(labels.items()): label_ops = LabelOperations(values) - label_column = json_extract( - orm_models.EventResource.resource, label - ) + label_column = orm_models.EventResource.resource[label].astext if label_ops.negative.simple or label_ops.negative.prefixes: label_filters.append(label_column.is_not(None)) diff --git a/src/prefect/server/models/concurrency_limits_v2.py b/src/prefect/server/models/concurrency_limits_v2.py index 4bbf8ccbdbaa..1effd5d2327e 100644 --- a/src/prefect/server/models/concurrency_limits_v2.py +++ b/src/prefect/server/models/concurrency_limits_v2.py @@ -12,51 +12,25 @@ from prefect.server.database.interface import PrefectDBInterface -def greatest( - db: PrefectDBInterface, clamped_value: int, sql_value: ColumnElement -) -> ColumnElement: - # Determine the greatest value based on the database type - if db.dialect.name == "sqlite": - # `sa.func.greatest` isn't available in SQLite, fallback to using a - # `case` statement. - return sa.case((clamped_value > sql_value, clamped_value), else_=sql_value) - else: - return sa.func.greatest(clamped_value, sql_value) - - -def seconds_ago(db: PrefectDBInterface, field: ColumnElement) -> ColumnElement: - if db.dialect.name == "sqlite": - # `sa.func.timezone` isn't available in SQLite, fallback to using - # `julianday` . - return (sa.func.julianday("now") - sa.func.julianday(field)) * 86400.0 - else: - return sa.func.extract( - "epoch", - sa.func.timezone("UTC", sa.func.now()) - sa.func.timezone("UTC", field), - ).cast(sa.Float) - - -def active_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]: +def active_slots_after_decay() -> ColumnElement[float]: # Active slots will decay at a rate of `slot_decay_per_second` per second. 
- return greatest( - db, + return sa.func.greatest( 0, orm_models.ConcurrencyLimitV2.active_slots - sa.func.floor( orm_models.ConcurrencyLimitV2.slot_decay_per_second - * seconds_ago(db, orm_models.ConcurrencyLimitV2.updated) + * sa.func.date_diff_seconds(orm_models.ConcurrencyLimitV2.updated) ), ) -def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]: +def denied_slots_after_decay() -> ColumnElement[float]: # Denied slots decay at a rate of `slot_decay_per_second` per second if it's # greater than 0, otherwise it decays at a rate of `avg_slot_occupancy_seconds`. # The combination of `denied_slots` and `slot_decay_per_second` / # `avg_slot_occupancy_seconds` is used to by the API to give a best guess at # when slots will be available again. - return greatest( - db, + return sa.func.greatest( 0, orm_models.ConcurrencyLimitV2.denied_slots - sa.func.floor( @@ -73,7 +47,7 @@ def denied_slots_after_decay(db: PrefectDBInterface) -> ColumnElement[float]: ) ), ) - * seconds_ago(db, orm_models.ConcurrencyLimitV2.updated) + * sa.func.date_diff_seconds(orm_models.ConcurrencyLimitV2.updated) ), ) @@ -232,8 +206,8 @@ async def bulk_increment_active_slots( concurrency_limit_ids: List[UUID], slots: int, ) -> bool: - active_slots = active_slots_after_decay(db) - denied_slots = denied_slots_after_decay(db) + active_slots = active_slots_after_decay() + denied_slots = denied_slots_after_decay() query = ( sa.update(orm_models.ConcurrencyLimitV2) diff --git a/src/prefect/server/models/deployments.py b/src/prefect/server/models/deployments.py index baa5e5899622..bfc0cc00442b 100644 --- a/src/prefect/server/models/deployments.py +++ b/src/prefect/server/models/deployments.py @@ -4,7 +4,7 @@ """ import datetime -from typing import Dict, Iterable, List, Optional, Sequence, TypeVar, cast +from typing import Any, Dict, Iterable, List, Optional, Sequence, TypeVar, cast from uuid import UUID, uuid4 import pendulum @@ -21,7 +21,7 @@ from prefect.server.exceptions import 
ObjectNotFoundError from prefect.server.models.events import deployment_status_event from prefect.server.schemas.statuses import DeploymentStatus -from prefect.server.utilities.database import json_contains +from prefect.server.utilities.database import JSON from prefect.settings import ( PREFECT_API_SERVICES_SCHEDULER_MAX_RUNS, PREFECT_API_SERVICES_SCHEDULER_MAX_SCHEDULED_TIME, @@ -795,7 +795,7 @@ async def check_work_queues_for_deployment( - Our database currently allows either "null" and empty lists as null values in filters, so we need to catch both cases with "or". - - `json_contains(A, B)` should be interpreted as "True if A + - `A.contains(B)` should be interpreted as "True if A contains B". Returns: @@ -805,6 +805,9 @@ async def check_work_queues_for_deployment( if not deployment: raise ObjectNotFoundError(f"Deployment with id {deployment_id} not found") + def json_contains(a: Any, b: Any) -> sa.ColumnElement[bool]: + return sa.type_coerce(a, type_=JSON).contains(sa.type_coerce(b, type_=JSON)) + query = ( select(orm_models.WorkQueue) # work queue tags are a subset of deployment tags diff --git a/src/prefect/server/schemas/filters.py b/src/prefect/server/schemas/filters.py index e5e1112c40a9..40ba0fdcf5b8 100644 --- a/src/prefect/server/schemas/filters.py +++ b/src/prefect/server/schemas/filters.py @@ -151,11 +151,9 @@ class FlowFilterTags(PrefectOperatorFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import json_has_all_keys - filters = [] if self.all_ is not None: - filters.append(json_has_all_keys(orm_models.Flow.tags, self.all_)) + filters.append(orm_models.Flow.tags.has_all(self.all_)) if self.is_null_ is not None: filters.append( orm_models.Flow.tags == [] @@ -266,16 +264,11 @@ class FlowRunFilterTags(PrefectOperatorFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import ( - json_has_all_keys, - json_has_any_key, - ) - filters = [] if self.all_ is not None: 
- filters.append(json_has_all_keys(orm_models.FlowRun.tags, self.all_)) + filters.append(orm_models.FlowRun.tags.has_all(self.all_)) if self.any_ is not None: - filters.append(json_has_any_key(orm_models.FlowRun.tags, self.any_)) + filters.append(orm_models.FlowRun.tags.has_any(self.any_)) if self.is_null_ is not None: filters.append( orm_models.FlowRun.tags == [] @@ -765,11 +758,9 @@ class TaskRunFilterTags(PrefectOperatorFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import json_has_all_keys - filters = [] if self.all_ is not None: - filters.append(json_has_all_keys(orm_models.TaskRun.tags, self.all_)) + filters.append(orm_models.TaskRun.tags.has_all(self.all_)) if self.is_null_ is not None: filters.append( orm_models.TaskRun.tags == [] @@ -1083,11 +1074,9 @@ class DeploymentFilterTags(PrefectOperatorFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import json_has_all_keys - filters = [] if self.all_ is not None: - filters.append(json_has_all_keys(orm_models.Deployment.tags, self.all_)) + filters.append(orm_models.Deployment.tags.has_all(self.all_)) if self.is_null_ is not None: filters.append( orm_models.Deployment.tags == [] @@ -1420,13 +1409,9 @@ class BlockSchemaFilterCapabilities(PrefectFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import json_has_all_keys - filters = [] if self.all_ is not None: - filters.append( - json_has_all_keys(orm_models.BlockSchema.capabilities, self.all_) - ) + filters.append(orm_models.BlockSchema.capabilities.has_all(self.all_)) return filters @@ -2169,11 +2154,9 @@ class VariableFilterTags(PrefectOperatorFilterBaseModel): ) def _get_filter_list(self) -> List: - from prefect.server.utilities.database import json_has_all_keys - filters = [] if self.all_ is not None: - filters.append(json_has_all_keys(orm_models.Variable.tags, self.all_)) + 
filters.append(orm_models.Variable.tags.has_all(self.all_)) if self.is_null_ is not None: filters.append( orm_models.Variable.tags == [] diff --git a/src/prefect/server/utilities/database.py b/src/prefect/server/utilities/database.py index c9f38019860d..12875e4b1440 100644 --- a/src/prefect/server/utilities/database.py +++ b/src/prefect/server/utilities/database.py @@ -7,23 +7,46 @@ import datetime import json +import operator import re import uuid -from typing import List, Optional, Union +from functools import partial +from typing import ( + Any, + Callable, + Optional, + TypeVar, + Union, + overload, +) import pendulum import pydantic import sqlalchemy as sa from sqlalchemy.dialects import postgresql, sqlite +from sqlalchemy.dialects.postgresql.operators import ( + # these are all incompletely annotated + ASTEXT, # type: ignore + CONTAINS, # type: ignore + HAS_ALL, # type: ignore + HAS_ANY, # type: ignore +) from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql.functions import FunctionElement -from sqlalchemy.sql.sqltypes import BOOLEAN +from sqlalchemy.orm import Session +from sqlalchemy.sql import functions, schema +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.operators import OperatorType +from sqlalchemy.sql.visitors import replacement_traverse from sqlalchemy.types import CHAR, TypeDecorator, TypeEngine +from typing_extensions import TypeAlias -camel_to_snake = re.compile(r"(? str: """ Generates a random UUID in Postgres; requires the pgcrypto extension. """ @@ -44,7 +68,9 @@ def _generate_uuid_postgresql(element, compiler, **kwargs): @compiles(GenerateUUID, "sqlite") -def _generate_uuid_sqlite(element, compiler, **kwargs): +def generate_uuid_sqlite( + element: GenerateUUID, compiler: SQLCompiler, **kwargs: Any +) -> str: """ Generates a random UUID in other databases (SQLite) by concatenating bytes in a way that approximates a UUID hex representation. 
This is @@ -68,7 +94,7 @@ def _generate_uuid_sqlite(element, compiler, **kwargs): """ -class Timestamp(TypeDecorator): +class Timestamp(TypeDecorator[pendulum.DateTime]): """TypeDecorator that ensures that timestamps have a timezone. For SQLite, all timestamps are converted to UTC (since they are stored @@ -78,35 +104,23 @@ class Timestamp(TypeDecorator): impl = sa.TIMESTAMP(timezone=True) cache_ok = True - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(postgresql.TIMESTAMP(timezone=True)) elif dialect.name == "sqlite": - return dialect.type_descriptor( - sqlite.DATETIME( - # SQLite is very particular about datetimes, and performs all comparisons - # as alphanumeric comparisons without regard for actual timestamp - # semantics or timezones. Therefore, it's important to have uniform - # and sortable datetime representations. The default is an ISO8601-compatible - # string with NO time zone and a space (" ") delimiter between the date - # and the time. The below settings can be used to add a "T" delimiter but - # will require all other sqlite datetimes to be set similarly, including - # the custom default value for datetime columns and any handwritten SQL - # formed with `strftime()`. - # - # store with "T" separator for time - # storage_format=( - # "%(year)04d-%(month)02d-%(day)02d" - # "T%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" - # ), - # handle ISO 8601 with "T" or " " as the time separator - # regexp=r"(\d+)-(\d+)-(\d+)[T ](\d+):(\d+):(\d+).(\d+)", - ) - ) + # see the sqlite.DATETIME docstring on the particulars of the storage + # format. Note that the sqlite implementations for timestamp and interval + # arithmetic below would require updating if a different format was to + # be configured here. 
+ return dialect.type_descriptor(sqlite.DATETIME()) else: return dialect.type_descriptor(sa.TIMESTAMP(timezone=True)) - def process_bind_param(self, value, dialect): + def process_bind_param( + self, + value: Optional[pendulum.DateTime], + dialect: sa.Dialect, + ) -> Optional[pendulum.DateTime]: if value is None: return None else: @@ -117,13 +131,17 @@ def process_bind_param(self, value, dialect): else: return value - def process_result_value(self, value, dialect): + def process_result_value( + self, + value: Optional[Union[datetime.datetime, pendulum.DateTime]], + dialect: sa.Dialect, + ) -> Optional[pendulum.DateTime]: # retrieve timestamps in their native timezone (or UTC) if value is not None: - return pendulum.instance(value).in_timezone("utc") + return pendulum.instance(value).in_timezone("UTC") -class UUID(TypeDecorator): +class UUID(TypeDecorator[uuid.UUID]): """ Platform-independent UUID type. @@ -135,13 +153,15 @@ class UUID(TypeDecorator): impl = TypeEngine cache_ok = True - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(postgresql.UUID()) else: return dialect.type_descriptor(CHAR(36)) - def process_bind_param(self, value, dialect): + def process_bind_param( + self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect + ) -> Optional[str]: if value is None: return None elif dialect.name == "postgresql": @@ -151,7 +171,9 @@ def process_bind_param(self, value, dialect): else: return str(uuid.UUID(value)) - def process_result_value(self, value, dialect): + def process_result_value( + self, value: Optional[Union[str, uuid.UUID]], dialect: sa.Dialect + ) -> Optional[uuid.UUID]: if value is None: return value else: @@ -160,7 +182,7 @@ def process_result_value(self, value, dialect): return value -class JSON(TypeDecorator): +class JSON(TypeDecorator[Any]): """ JSON type that returns SQLAlchemy's dialect-specific JSON types, where 
possible. Uses generic JSON otherwise. @@ -172,7 +194,7 @@ class JSON(TypeDecorator): impl = postgresql.JSONB cache_ok = True - def load_dialect_impl(self, dialect): + def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: if dialect.name == "postgresql": return dialect.type_descriptor(postgresql.JSONB(none_as_null=True)) elif dialect.name == "sqlite": @@ -180,7 +202,9 @@ def load_dialect_impl(self, dialect): else: return dialect.type_descriptor(sa.JSON(none_as_null=True)) - def process_bind_param(self, value, dialect): + def process_bind_param( + self, value: Optional[Any], dialect: sa.Dialect + ) -> Optional[Any]: """Prepares the given value to be used as a JSON field in a parameter binding""" if not value: return value @@ -199,7 +223,7 @@ def process_bind_param(self, value, dialect): return json.loads(json.dumps(value), parse_constant=lambda c: None) -class Pydantic(TypeDecorator): +class Pydantic(TypeDecorator[T]): """ A pydantic type that converts inserted parameters to json and converts read values to the pydantic type. @@ -208,13 +232,38 @@ class Pydantic(TypeDecorator): impl = JSON cache_ok = True - def __init__(self, pydantic_type, sa_column_type=None): + @overload + def __init__( + self, + pydantic_type: type[T], + sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None, + ) -> None: + ... + + # This overload is needed to allow for typing special forms (e.g. + # Union[...], etc.) as these can't be married with `type[...]`. Also see + # https://github.com/pydantic/pydantic/pull/8923 + @overload + def __init__( + self: "Pydantic[Any]", + pydantic_type: Any, + sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None, + ) -> None: + ... 
+ + def __init__( + self, + pydantic_type: type[T], + sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None, + ) -> None: super().__init__() self._pydantic_type = pydantic_type if sa_column_type is not None: self.impl = sa_column_type - def process_bind_param(self, value, dialect) -> Optional[str]: + def process_bind_param( + self, value: Optional[T], dialect: sa.Dialect + ) -> Optional[str]: if value is None: return None @@ -229,25 +278,18 @@ def process_bind_param(self, value, dialect) -> Optional[str]: # it into a python-native form. return adapter.dump_python(value, mode="json") - def process_result_value(self, value, dialect): + def process_result_value( + self, value: Optional[Any], dialect: sa.Dialect + ) -> Optional[T]: if value is not None: # load the json object into a fully hydrated typed object return pydantic.TypeAdapter(self._pydantic_type).validate_python(value) -class now(FunctionElement): - """ - Platform-independent "now" generator. - """ - - type = Timestamp() - name = "now" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = True - - -@compiles(now, "sqlite") -def _current_timestamp_sqlite(element, compiler, **kwargs): +@compiles(functions.now, "sqlite") +def current_timestamp_sqlite( + element: functions.now, compiler: SQLCompiler, **kwargs: Any +) -> str: """ Generates the current timestamp for SQLite @@ -264,386 +306,380 @@ def _current_timestamp_sqlite(element, compiler, **kwargs): return "strftime('%Y-%m-%d %H:%M:%f000', 'now')" -@compiles(now) -def _current_timestamp(element, compiler, **kwargs): - """ - Generates the current timestamp in standard SQL - """ - return "CURRENT_TIMESTAMP" +# Platform-independent datetime and timedelta arithmetic functions -class date_add(FunctionElement): - """ - Platform-independent way to add a date and an interval. 
- """ +class date_add(functions.GenericFunction[pendulum.DateTime]): + """Platform-independent way to add a timestamp and an interval""" type = Timestamp() - name = "date_add" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False - - def __init__(self, dt, interval): - self.dt = dt - self.interval = interval - super().__init__() - - -@compiles(date_add, "postgresql") -@compiles(date_add) -def _date_add_postgresql(element, compiler, **kwargs): - return compiler.process( - sa.cast(element.dt, Timestamp()) + sa.cast(element.interval, sa.Interval()) - ) - - -@compiles(date_add, "sqlite") -def _date_add_sqlite(element, compiler, **kwargs): - """ - In sqlite, we represent intervals as datetimes after the epoch, following - SQLAlchemy convention for the Interval() type. - """ + inherit_cache = True - dt = element.dt - if isinstance(dt, datetime.datetime): - dt = str(dt) - - interval = element.interval - if isinstance(interval, datetime.timedelta): - interval = str(pendulum.datetime(1970, 1, 1) + interval) - - return compiler.process( - # convert to date - sa.func.strftime( - "%Y-%m-%d %H:%M:%f000", - sa.func.julianday(dt) - + ( - # convert interval to fractional days after the epoch - sa.func.julianday(interval) - 2440587.5 - ), + def __init__( + self, + dt: _SQLExpressionOrLiteral[datetime.datetime], + interval: _SQLExpressionOrLiteral[datetime.timedelta], + **kwargs: Any, + ): + super().__init__( + sa.type_coerce(dt, Timestamp()), + sa.type_coerce(interval, sa.Interval()), + **kwargs, ) - ) -class interval_add(FunctionElement): - """ - Platform-independent way to add two intervals. 
- """ +class interval_add(functions.GenericFunction[datetime.timedelta]): + """Platform-independent way to add two intervals.""" type = sa.Interval() - name = "interval_add" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False - - def __init__(self, i1, i2): - self.i1 = i1 - self.i2 = i2 - super().__init__() - - -@compiles(interval_add, "postgresql") -@compiles(interval_add) -def _interval_add_postgresql(element, compiler, **kwargs): - return compiler.process( - sa.cast(element.i1, sa.Interval()) + sa.cast(element.i2, sa.Interval()) - ) - - -@compiles(interval_add, "sqlite") -def _interval_add_sqlite(element, compiler, **kwargs): - """ - In sqlite, we represent intervals as datetimes after the epoch, following - SQLAlchemy convention for the Interval() type. - - Therefore the sum of two intervals is - - (i1 - epoch) + (i2 - epoch) = i1 + i2 - epoch - """ - - i1 = element.i1 - if isinstance(i1, datetime.timedelta): - i1 = str(pendulum.datetime(1970, 1, 1) + i1) - - i2 = element.i2 - if isinstance(i2, datetime.timedelta): - i2 = str(pendulum.datetime(1970, 1, 1) + i2) + inherit_cache = True - return compiler.process( - # convert to date - sa.func.strftime( - "%Y-%m-%d %H:%M:%f000", - sa.func.julianday(i1) + sa.func.julianday(i2) - 2440587.5, + def __init__( + self, + i1: _SQLExpressionOrLiteral[datetime.timedelta], + i2: _SQLExpressionOrLiteral[datetime.timedelta], + **kwargs: Any, + ): + super().__init__( + sa.type_coerce(i1, sa.Interval()), + sa.type_coerce(i2, sa.Interval()), + **kwargs, ) - ) -class date_diff(FunctionElement): - """ - Platform-independent difference of dates. Computes d1 - d2. - """ +class date_diff(functions.GenericFunction[datetime.timedelta]): + """Platform-independent difference of two timestamps. 
Computes d1 - d2.""" type = sa.Interval() - name = "date_diff" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False - - def __init__(self, d1, d2): - self.d1 = d1 - self.d2 = d2 - super().__init__() - + inherit_cache = True -@compiles(date_diff, "postgresql") -@compiles(date_diff) -def _date_diff_postgresql(element, compiler, **kwargs): - return compiler.process( - sa.cast(element.d1, Timestamp()) - sa.cast(element.d2, Timestamp()) - ) + def __init__( + self, + d1: _SQLExpressionOrLiteral[datetime.datetime], + d2: _SQLExpressionOrLiteral[datetime.datetime], + **kwargs: Any, + ) -> None: + super().__init__( + sa.type_coerce(d1, Timestamp()), sa.type_coerce(d2, Timestamp()), **kwargs + ) -@compiles(date_diff, "sqlite") -def _date_diff_sqlite(element, compiler, **kwargs): - """ - In sqlite, we represent intervals as datetimes after the epoch, following - SQLAlchemy convention for the Interval() type. - """ - d1 = element.d1 - if isinstance(d1, datetime.datetime): - d1 = str(d1) - - d2 = element.d2 - if isinstance(d2, datetime.datetime): - d2 = str(d2) - - return compiler.process( - # convert to date - sa.func.strftime( - "%Y-%m-%d %H:%M:%f000", - # the epoch in julian days - 2440587.5 - # plus the date difference in julian days - + sa.func.julianday(d1) - - sa.func.julianday(d2), - ) - ) +class date_diff_seconds(functions.GenericFunction[float]): + """Platform-independent calculation of the number of seconds between two timestamps or from 'now'""" + type = sa.REAL + inherit_cache = True -class json_contains(FunctionElement): - """ - Platform independent json_contains operator, tests if the - `left` expression contains the `right` expression. 
+ def __init__( + self, + dt1: _SQLExpressionOrLiteral[datetime.datetime], + dt2: Optional[_SQLExpressionOrLiteral[datetime.datetime]] = None, + **kwargs: Any, + ) -> None: + args = (sa.type_coerce(dt1, Timestamp()),) + if dt2 is not None: + args = (*args, sa.type_coerce(dt2, Timestamp())) + super().__init__(*args, **kwargs) - On postgres this is equivalent to the @> containment operator. - https://www.postgresql.org/docs/current/functions-json.html - """ - type = BOOLEAN - name = "json_contains" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False +# timestamp and interval arithmetic implementations for PostgreSQL - def __init__(self, left, right): - self.left = left - self.right = right - super().__init__() +@compiles(date_add, "postgresql") +@compiles(interval_add, "postgresql") +@compiles(date_diff, "postgresql") +def datetime_or_interval_add_postgresql( + element: Union[date_add, interval_add, date_diff], + compiler: SQLCompiler, + **kwargs: Any, +) -> str: + match element: + case date_add() | interval_add(): + operation = operator.add + case date_diff(): + operation = operator.sub + return compiler.process(operation(*element.clauses)) + + +@compiles(date_diff_seconds, "postgresql") +def date_diff_seconds_postgresql( + element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any +) -> str: + # either 1 or 2 timestamps; if 1, subtract from 'now' + dts: list[sa.ColumnElement[datetime.datetime]] = list(element.clauses) + if len(dts) == 1: + dts = [sa.func.now(), *dts] + as_utc = (sa.func.timezone("UTC", dt) for dt in dts) + return compiler.process(sa.func.extract("epoch", operator.sub(*as_utc))) + + +# SQLite implementations for the Timestamp and Interval arithmetic functions. 
+# +# The following concepts are at play here: +# +# - By default, SQLAlchemy stores Timestamp values formatted as ISO8601 strings +# (with a space between the date and the time parts), with microsecond precision. +# - SQLAlchemy stores Interval values as a Timestamp, offset from the UNIX epoch. +# - SQLite processes timestamp values with _at most_ millisecond precision, and +# only if you use the `julianday()` function or the 'subsec' modifier for +# the `unixepoch()` function (the latter requires SQLite 3.42.0, released +# 2023-05-16) +# +# In order for arithmetic to work well, you need to convert timestamps to +# fractional [Julian day numbers][JDN], and intervals to a real number +# by subtracting the UNIX epoch from their Julian day number representation. +# +# Once the result has been computed, the result needs to be converted back +# to an ISO8601 formatted string including any milliseconds. For an +# interval result, that means adding the UNIX epoch offset to it first. +# +# [JDN]: https://en.wikipedia.org/wiki/Julian_day + +# The UNIX epoch, 1970-01-01T00:00:00Z, expressed as a fractional Julian day +# number.
+SQLITE_EPOCH_JULIANDAYNUMBER = sa.literal(2440587.5, literal_execute=True) +SECONDS_PER_DAY = sa.literal(24 * 60 * 60.0, literal_execute=True) + +# SQLite strftime() format to output ISO8601 date and time with milliseconds +SQLITE_DATETIME_FORMAT = sa.literal("%F %R:%f000", literal_execute=True) + + +_sqlite_strftime = partial(sa.func.strftime, SQLITE_DATETIME_FORMAT) +"""Format SQLite timestamp to a SQLAlchemy-compatible string""" + + +def _sqlite_strfinterval( + offset: sa.ColumnElement[float], +) -> sa.ColumnElement[datetime.datetime]: + """Format interval offset to a SQLAlchemy-compatible string""" + return _sqlite_strftime(SQLITE_EPOCH_JULIANDAYNUMBER + offset) + + +def _sqlite_interval_offset( + interval: _SQLExpressionOrLiteral[datetime.timedelta], +) -> sa.ColumnElement[float]: + """Convert interval value to a fraction Julian day number REAL offset from UNIX epoch""" + return sa.func.julianday(interval) - SQLITE_EPOCH_JULIANDAYNUMBER -@compiles(json_contains, "postgresql") -@compiles(json_contains) -def _json_contains_postgresql(element, compiler, **kwargs): - return compiler.process( - sa.type_coerce(element.left, postgresql.JSONB).contains( - sa.type_coerce(element.right, postgresql.JSONB) - ), - **kwargs, - ) +@compiles(date_add, "sqlite") +def date_add_sqlite(element: date_add, compiler: SQLCompiler, **kwargs: Any) -> str: + dt, interval = element.clauses + jdn, offset = sa.func.julianday(dt), _sqlite_interval_offset(interval) + # dt + interval, as fractional Julian day number values + return compiler.process(_sqlite_strftime(jdn + offset)) -def _json_contains_sqlite_fn(left, right, compiler, **kwargs): - # if the value is literal, convert to a JSON string - if isinstance(left, (list, dict, tuple, str)): - left = json.dumps(left) - # if the value is literal, convert to a JSON string - if isinstance(right, (list, dict, tuple, str)): - right = json.dumps(right) +@compiles(interval_add, "sqlite") +def interval_add_sqlite( + element: interval_add, 
compiler: SQLCompiler, **kwargs: Any +) -> str: + offsets = map(_sqlite_interval_offset, element.clauses) + # interval + interval, as fractional Julian day number values + return compiler.process(_sqlite_strfinterval(operator.add(*offsets))) - json_each_left = sa.func.json_each(left).alias("left") - json_each_right = sa.func.json_each(right).alias("right") - # compute equality by counting the number of distinct matches between - # the left items and the right items (e.g. the number of rows resulting from a join) - # and seeing if it exceeds the number of distinct keys in the right operand - # - # note that using distinct emulates postgres behavior to disregard duplicates - distinct_matches = ( - sa.select(sa.func.count(sa.distinct(sa.literal_column("left.value")))) - .select_from(json_each_left) - .join( - json_each_right, - sa.literal_column("left.value") == sa.literal_column("right.value"), - ) - .scalar_subquery() +@compiles(date_diff, "sqlite") +def date_diff_sqlite(element: date_diff, compiler: SQLCompiler, **kwargs: Any) -> str: + jdns = map(sa.func.julianday, element.clauses) + # timestamp - timestamp, as fractional Julian day number values + return compiler.process(_sqlite_strfinterval(operator.sub(*jdns))) + + +@compiles(date_diff_seconds, "sqlite") +def date_diff_seconds_sqlite( + element: date_diff_seconds, compiler: SQLCompiler, **kwargs: Any +) -> str: + # either 1 or 2 timestamps; if 1, subtract from 'now' + dts: list[sa.ColumnElement[Any]] = list(element.clauses) + if len(dts) == 1: + dts = [sa.literal("now", literal_execute=True), *dts] + as_jdn = (sa.func.julianday(dt) for dt in dts) + # timestamp - timestamp, as a fractional Julian day number, times the number of seconds in a day + return compiler.process(operator.sub(*as_jdn) * SECONDS_PER_DAY) + + +# PostgreSQL JSON(B) Comparator operators ported to SQLite + + +def _is_literal(elem: Any) -> bool: + """Element is not a SQLAlchemy SQL construct""" + # Copied from 
sqlalchemy.sql.coercions._is_literal + return not ( + isinstance(elem, (sa.Visitable, schema.SchemaEventTarget)) + or hasattr(elem, "__clause_element__") ) - distinct_keys = ( - sa.select(sa.func.count(sa.distinct(sa.literal_column("right.value")))) - .select_from(json_each_right) - .scalar_subquery() - ) - return compiler.process(distinct_matches >= distinct_keys) +def _postgresql_array_to_json_array( + elem: sa.ColumnElement[Any], +) -> sa.ColumnElement[Any]: + """Replace any postgresql array() literals with a json_array() function call + Because an _empty_ array leads to a PostgreSQL error, array() is often + coupled with a cast(); this function replaces arrays with or without + such a cast. -@compiles(json_contains, "sqlite") -def _json_contains_sqlite(element, compiler, **kwargs): - return _json_contains_sqlite_fn(element.left, element.right, compiler, **kwargs) + This allows us to map the postgres JSONB.has_any / JSONB.has_all operand to + SQLite. + Returns the updated expression. -class json_extract(FunctionElement): """ - Platform independent json_extract operator, extracts a value from a JSON - field via key. - On postgres this is equivalent to the ->> operator. - https://www.postgresql.org/docs/current/functions-json.html - """ + def _replacer(element: Any, **kwargs: Any) -> Optional[Any]: + match element: + # either array(...), or cast(array(...), ...)
+ case ( + sa.Cast(clause=postgresql.array(clauses=clauses)) + | postgresql.array(clauses=clauses) + ): + return sa.func.json_array(*clauses) + case _: + return None - type = sa.Text() - name = "json_extract" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False + opts: dict[str, Any] = {} + return replacement_traverse(elem, opts, _replacer) - def __init__(self, column: sa.Column, path: str, wrap_quotes: bool = False): - self.column = column - self.path = path - self.wrap_quotes = wrap_quotes - super().__init__() +def _json_each(elem: sa.ColumnElement[Any]) -> sa.TableValuedAlias: + """SQLite json_each() table-valued construct -@compiles(json_extract, "postgresql") -@compiles(json_extract) -def _json_extract_postgresql(element, compiler, **kwargs): - return "%s ->> '%s'" % (compiler.process(element.column, **kwargs), element.path) + Configures a SQLAlchemy table-valued object with the minimum + column definitions and correct configuration. + """ + return sa.func.json_each(elem).table_valued("key", "value", joins_implicitly=True) -@compiles(json_extract, "sqlite") -def _json_extract_sqlite(element, compiler, **kwargs): - path = element.path.replace("'", "''") # escape single quotes for JSON path - if element.wrap_quotes: - path = f'"{path}"' - return "JSON_EXTRACT(%s, '$.%s')" % ( - compiler.process(element.column, **kwargs), - path, - ) +# sqlite JSON operator implementations. -class json_has_any_key(FunctionElement): - """ - Platform independent json_has_any_key operator. - On postgres this is equivalent to the ?| existence operator.
- https://www.postgresql.org/docs/current/functions-json.html - """ +def _sqlite_json_astext( + element: sa.BinaryExpression[Any], +) -> sa.BinaryExpression[Any]: + """Map postgres JSON.astext / JSONB.astext (`->>`) to sqlite json_extract() - type = BOOLEAN - name = "json_has_any_key" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False + Without the `as_string()` call, SQLAlchemy outputs json_quote(json_extract(...)) - def __init__(self, json_expr, values: List): - self.json_expr = json_expr - if not all(isinstance(v, str) for v in values): - raise ValueError("json_has_any_key values must be strings") - self.values = values - super().__init__() + """ + return element.left[element.right].as_string() -@compiles(json_has_any_key, "postgresql") -@compiles(json_has_any_key) -def _json_has_any_key_postgresql(element, compiler, **kwargs): - values_array = postgresql.array(element.values) - # if the array is empty, postgres requires a type annotation - if not element.values: - values_array = sa.cast(values_array, postgresql.ARRAY(sa.String)) +def _sqlite_json_contains( + element: sa.BinaryExpression[bool], +) -> sa.ColumnElement[bool]: + """Map JSONB.contains() and JSONB.has_all() to a SQLite expression""" + # left can be a JSON value as a (Python) literal, or a SQL expression for a JSON value + # right can be a SQLA postgresql.array() literal or a SQL expression for a + # JSON array (for .has_all()) or it can be a JSON value as a (Python) + # literal or a SQL expression for a JSON object (for .contains()) + left, right = element.left, element.right - return compiler.process( - sa.type_coerce(element.json_expr, postgresql.JSONB).has_any(values_array), - **kwargs, - ) + # if either top-level operand is literal, convert to a JSON bindparam + if _is_literal(left): + left = sa.bindparam("haystack", left, expanding=True, type_=JSON) + if _is_literal(right): + right = sa.bindparam("needles", right, 
expanding=True, type_=JSON) + else: + # convert the array() literal used in JSONB.has_all() to a JSON array. + right = _postgresql_array_to_json_array(right) + jleft, jright = _json_each(left), _json_each(right) -@compiles(json_has_any_key, "sqlite") -def _json_has_any_key_sqlite(element, compiler, **kwargs): - # attempt to match any of the provided values at least once - json_each = sa.func.json_each(element.json_expr).alias("json_each") - return compiler.process( - sa.select(1) - .select_from(json_each) - .where( - sa.literal_column("json_each.value").in_( - # manually set the bindparam key because the default will - # include the `.` from the literal column name and sqlite params - # must be alphanumeric. `unique=True` automatically suffixes the bindparam - # if there are overlaps. - sa.bindparam(key="json_each_values", value=element.values, unique=True) - ) - ) - .exists(), - **kwargs, + # compute equality by counting the number of distinct matches between the + # left items and the right items (e.g. the number of rows resulting from a + # join) and seeing if it exceeds the number of distinct keys in the right + # operand. + # + # note that using distinct emulates postgres behavior to disregard duplicates + distinct_matches = ( + sa.select(sa.func.count(sa.distinct(jleft.c.value))) + .join(jright, onclause=jleft.c.value == jright.c.value) + .scalar_subquery() ) + distinct_keys = sa.select( + sa.func.count(sa.distinct(jright.c.value)) + ).scalar_subquery() + return distinct_matches >= distinct_keys -class json_has_all_keys(FunctionElement): - """Platform independent json_has_all_keys operator. - - On postgres this is equivalent to the ?& existence operator. 
- https://www.postgresql.org/docs/current/functions-json.html - """ - - type = BOOLEAN - name = "json_has_all_keys" - # see https://docs.sqlalchemy.org/en/14/core/compiler.html#enabling-caching-support-for-custom-constructs - inherit_cache = False - - def __init__(self, json_expr, values: List): - self.json_expr = json_expr - if isinstance(values, list) and not all(isinstance(v, str) for v in values): - raise ValueError( - "json_has_all_key values must be strings if provided as a literal list" - ) - self.values = values - super().__init__() +def _sqlite_json_has_any(element: sa.BinaryExpression[bool]) -> sa.ColumnElement[bool]: + """Map JSONB.has_any() to a SQLite expression""" + # left can be a JSON value as a (Python) literal, or a SQL expression for a JSON value + # right can be a SQLA postgresql.array() literal or a SQL expression for a JSON array + left, right = element.left, element.right -@compiles(json_has_all_keys, "postgresql") -@compiles(json_has_all_keys) -def _json_has_all_keys_postgresql(element, compiler, **kwargs): - values_array = postgresql.array(element.values) + # convert the array() literal used in JSONB.has_any() to a JSON array.
+ right = _postgresql_array_to_json_array(right) - # if the array is empty, postgres requires a type annotation - if not element.values: - values_array = sa.cast(values_array, postgresql.ARRAY(sa.String)) + jleft, jright = _json_each(left), _json_each(right) - return compiler.process( - sa.type_coerce(element.json_expr, postgresql.JSONB).has_all(values_array), - **kwargs, + # deal with "json array ?| [value, ...]"" vs "json object ?| [key, ...]" tests + # if left is a JSON object, match keys, else match values; the latter works + # for arrays and all JSON scalar types + json_object = sa.literal("object", literal_execute=True) + left_elem = sa.case( + (sa.func.json_type(element.left) == json_object, jleft.c.key), + else_=jleft.c.value, ) + return sa.exists().where(left_elem == jright.c.value) + + +# Map of SQLA postgresql JSON/JSONB operators and a function to rewrite +# a BinaryExpression with such an operator to their SQLite equivalent. +_sqlite_json_operator_map: dict[ + OperatorType, Callable[[sa.BinaryExpression[Any]], sa.ColumnElement[Any]] +] = { + ASTEXT: _sqlite_json_astext, + CONTAINS: _sqlite_json_contains, + HAS_ALL: _sqlite_json_contains, # "has all" is equivalent to "contains" + HAS_ANY: _sqlite_json_has_any, +} + + +@compiles(sa.BinaryExpression, "sqlite") +def sqlite_json_operators( + element: sa.BinaryExpression[Any], + compiler: SQLCompiler, + override_operator: Optional[OperatorType] = None, + **kwargs: Any, +) -> str: + """Intercept the PostgreSQL-only JSON / JSONB operators and translate them to SQLite""" + operator = override_operator or element.operator + if (handler := _sqlite_json_operator_map.get(operator)) is not None: + return compiler.process(handler(element), **kwargs) + # ignore reason: SQLA compilation hooks are not as well covered with type annotations + return compiler.visit_binary(element, override_operator=operator, **kwargs) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType] + + +class 
greatest(functions.ReturnTypeFromArgs[T]): + name = "greatest" + inherit_cache = True -@compiles(json_has_all_keys, "sqlite") -def _json_has_all_keys_sqlite(element, compiler, **kwargs): - # "has all keys" is equivalent to "json contains" - return _json_contains_sqlite_fn( - left=element.json_expr, - right=element.values, - compiler=compiler, - **kwargs, - ) + +@compiles(greatest, "sqlite") +def sqlite_greatest_as_max( + element: greatest[Any], compiler: SQLCompiler, **kw: Any +) -> str: + # TODO: SQLite MAX() is very close to PostgreSQL GREATEST(), *except* when + # it comes to nulls: SQLite MAX() returns NULL if _any_ clause is NULL, + # whereas PostgreSQL GREATEST() only returns NULL if _all_ clauses are NULL. + # + # A work-around is to use MAX() as an aggregate function instead, in a + # subquery. This, however, would probably require a VALUES-like construct + # that SQLA doesn't currently support for SQLite. You can [provide + # compilation hooks for + # this](https://github.com/sqlalchemy/sqlalchemy/issues/7228#issuecomment-1746837960) + # but this would only be worth it if sa.func.greatest() starts being used on + # values that include NULLs. Up until the time of this comment this hasn't + # been an issue. + return compiler.process(sa.func.max(*element.clauses), **kw) -def get_dialect( - obj: Union[str, sa.orm.Session, sa.engine.Engine], -) -> sa.engine.Dialect: +def get_dialect(obj: Union[str, Session, sa.Engine]) -> type[sa.Dialect]: """ Get the dialect of a session, engine, or connection url. 
@@ -662,9 +698,11 @@ def get_dialect( print("Using Postgres!") ``` """ - if isinstance(obj, sa.orm.Session): - url = obj.bind.url - elif isinstance(obj, sa.engine.Engine): + if isinstance(obj, Session): + assert obj.bind is not None + obj = obj.bind.engine if isinstance(obj.bind, sa.Connection) else obj.bind + + if isinstance(obj, sa.engine.Engine): url = obj.url else: url = sa.engine.url.make_url(obj) diff --git a/tests/server/utilities/test_database.py b/tests/server/utilities/test_database.py index ab0f2731d300..e3e54295f905 100644 --- a/tests/server/utilities/test_database.py +++ b/tests/server/utilities/test_database.py @@ -2,34 +2,25 @@ import enum import math import sqlite3 -from typing import List +from typing import Any, Optional, Union from unittest import mock import pendulum import pydantic import pytest import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY, array from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import declarative_base +from sqlalchemy.ext.asyncio.engine import AsyncEngine +from sqlalchemy.orm import Mapped, declarative_base, mapped_column from prefect.server.database.configurations import AioSqliteConfiguration from prefect.server.database.interface import PrefectDBInterface from prefect.server.database.orm_models import AioSqliteORMConfiguration from prefect.server.database.query_components import AioSqliteQueryComponents -from prefect.server.utilities.database import ( - JSON, - Pydantic, - Timestamp, - date_add, - date_diff, - interval_add, - json_contains, - json_extract, - json_has_all_keys, - json_has_any_key, -) - -DBBase = declarative_base() +from prefect.server.utilities.database import JSON, Pydantic, Timestamp + +DBBase = declarative_base(type_annotation_map={pendulum.DateTime: Timestamp}) class PydanticModel(pydantic.BaseModel): @@ -45,31 +36,35 @@ class Color(enum.Enum): class SQLPydanticModel(DBBase): __tablename__ = "_test_pydantic_model" - id = sa.Column(sa.Integer, 
primary_key=True, autoincrement=True) - data = sa.Column(Pydantic(PydanticModel)) - data_list = sa.Column(Pydantic(List[PydanticModel])) - color = sa.Column(Pydantic(Color, sa_column_type=sa.Text())) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + data: Mapped[Optional[PydanticModel]] = mapped_column(Pydantic(PydanticModel)) + data_list: Mapped[Optional[list[PydanticModel]]] = mapped_column( + Pydantic(list[PydanticModel]) + ) + color: Mapped[Optional[Color]] = mapped_column( + Pydantic(Color, sa_column_type=sa.Text()) + ) class SQLTimestampModel(DBBase): __tablename__ = "_test_timestamp_model" - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - ts_1 = sa.Column(Timestamp) - ts_2 = sa.Column(Timestamp) - i_1 = sa.Column(sa.Interval) - i_2 = sa.Column(sa.Interval) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + ts_1: Mapped[Optional[pendulum.DateTime]] + ts_2: Mapped[Optional[pendulum.DateTime]] + i_1: Mapped[Optional[datetime.timedelta]] + i_2: Mapped[Optional[datetime.timedelta]] class SQLJSONModel(DBBase): __tablename__ = "_test_json_model" - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - data = sa.Column(JSON) + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + data: Mapped[Any] = mapped_column(JSON) @pytest.fixture(scope="module", autouse=True) -async def create_database_models(database_engine): +async def create_database_models(database_engine: AsyncEngine): """ Add the models defined in this file to the database """ @@ -84,21 +79,18 @@ async def create_database_models(database_engine): @pytest.fixture(scope="function", autouse=True) -async def clear_database_models(db): +async def clear_database_models(db: PrefectDBInterface): """ Clears the models defined in this file """ yield async with db.session_context(begin_transaction=True) as session: - # work pool has a circular dependency on pool queue; delete it first - await 
session.execute(db.WorkPool.__table__.delete()) - for table in reversed(DBBase.metadata.sorted_tables): await session.execute(table.delete()) class TestPydantic: - async def test_write_to_Pydantic(self, session): + async def test_write_to_Pydantic(self, session: AsyncSession): p_model = PydanticModel(x=100) s_model = SQLPydanticModel(data=p_model) session.add(s_model) @@ -107,13 +99,13 @@ async def test_write_to_Pydantic(self, session): # clear cache session.expire_all() - query = await session.execute(sa.select(SQLPydanticModel)) - results = query.scalars().all() + query = await session.scalars(sa.select(SQLPydanticModel)) + results = query.all() assert len(results) == 1 assert isinstance(results[0].data, PydanticModel) assert results[0].data.y < pendulum.now("UTC") - async def test_write_dict_to_Pydantic(self, session): + async def test_write_dict_to_Pydantic(self, session: AsyncSession): p_model = PydanticModel(x=100) s_model = SQLPydanticModel(data=p_model.model_dump()) session.add(s_model) @@ -122,12 +114,12 @@ async def test_write_dict_to_Pydantic(self, session): # clear cache session.expire_all() - query = await session.execute(sa.select(SQLPydanticModel)) - results = query.scalars().all() + query = await session.scalars(sa.select(SQLPydanticModel)) + results = query.all() assert len(results) == 1 assert isinstance(results[0].data, PydanticModel) - async def test_nullable_Pydantic(self, session): + async def test_nullable_Pydantic(self, session: AsyncSession): s_model = SQLPydanticModel(data=None) session.add(s_model) await session.flush() @@ -135,12 +127,12 @@ async def test_nullable_Pydantic(self, session): # clear cache session.expire_all() - query = await session.execute(sa.select(SQLPydanticModel)) - results = query.scalars().all() + query = await session.scalars(sa.select(SQLPydanticModel)) + results = query.all() assert len(results) == 1 assert results[0].data is None - async def test_generic_model(self, session): + async def test_generic_model(self, 
session: AsyncSession): p_model = PydanticModel(x=100) s_model = SQLPydanticModel(data_list=[p_model]) session.add(s_model) @@ -149,31 +141,32 @@ async def test_generic_model(self, session): # clear cache session.expire_all() - query = await session.execute(sa.select(SQLPydanticModel)) - results = query.scalars().all() + query = await session.scalars(sa.select(SQLPydanticModel)) + results = query.all() assert len(results) == 1 + assert results[0].data_list is not None assert isinstance(results[0].data_list[0], PydanticModel) assert results[0].data_list == [p_model] - async def test_generic_model_validates(self, session): + async def test_generic_model_validates(self, session: AsyncSession): p_model = PydanticModel(x=100) s_model = SQLPydanticModel(data_list=p_model) session.add(s_model) with pytest.raises(sa.exc.StatementError, match="(validation error)"): await session.flush() - async def test_write_to_enum_field(self, session): + async def test_write_to_enum_field(self, session: AsyncSession): s_model = SQLPydanticModel(color="RED") session.add(s_model) await session.flush() - async def test_write_to_enum_field_is_validated(self, session): + async def test_write_to_enum_field_is_validated(self, session: AsyncSession): s_model = SQLPydanticModel(color="GREEN") session.add(s_model) with pytest.raises(sa.exc.StatementError, match="(validation error)"): await session.flush() - async def test_enum_field_is_a_string_in_database(self, session): + async def test_enum_field_is_a_string_in_database(self, session: AsyncSession): s_model = SQLPydanticModel(color="RED") session.add(s_model) await session.flush() @@ -195,13 +188,13 @@ async def test_enum_field_is_a_string_in_database(self, session): class TestTimestamp: - async def test_error_if_naive_timestamp_passed(self, session): + async def test_error_if_naive_timestamp_passed(self, session: AsyncSession): model = SQLTimestampModel(ts_1=datetime.datetime(2000, 1, 1)) session.add(model) with 
pytest.raises(sa.exc.StatementError, match="(must have a timezone)"): await session.flush() - async def test_timestamp_converted_to_utc(self, session): + async def test_timestamp_converted_to_utc(self, session: AsyncSession): model = SQLTimestampModel( ts_1=datetime.datetime(2000, 1, 1, tzinfo=pendulum.timezone("EST")) ) @@ -211,15 +204,16 @@ async def test_timestamp_converted_to_utc(self, session): # clear cache session.expire_all() - query = await session.execute(sa.select(SQLTimestampModel)) - results = query.scalars().all() + query = await session.scalars(sa.select(SQLTimestampModel)) + results = query.all() assert results[0].ts_1 == model.ts_1 + assert results[0].ts_1 is not None assert results[0].ts_1.tzinfo == pendulum.timezone("UTC") class TestJSON: @pytest.fixture(autouse=True) - async def data(self, session): + async def data(self, session: AsyncSession): session.add_all( [ SQLJSONModel(id=1, data=["a"]), @@ -232,9 +226,11 @@ async def data(self, session): ) await session.commit() - async def get_ids(self, session, query): - result = await session.execute(query) - return [r.id for r in result.scalars().all()] + async def get_ids( + self, session: AsyncSession, query: sa.Select[tuple[SQLJSONModel]] + ) -> list[int]: + result = await session.scalars(query) + return [r.id for r in result] @pytest.mark.parametrize( "keys,ids", @@ -255,10 +251,19 @@ async def get_ids(self, session, query): (["a", "a", "a"], [1, 3, 4]), ], ) - async def test_json_contains_right_side_literal(self, session, keys, ids): + async def test_json_contains_right_side_literal( + self, + session: AsyncSession, + keys: list[str] + | list[dict[str, str]] + | list[int] + | list[list[int]] + | list[list[int] | int], + ids: list[int], + ): query = ( sa.select(SQLJSONModel) - .where(json_contains(SQLJSONModel.data, keys)) + .where(SQLJSONModel.data.contains(keys)) .order_by(SQLJSONModel.id) ) assert await self.get_ids(session, query) == ids @@ -276,10 +281,15 @@ async def 
test_json_contains_right_side_literal(self, session, keys, ids): (["a", "a", "a"], [1]), ], ) - async def test_json_contains_left_side_literal(self, session, keys, ids): + async def test_json_contains_left_side_literal( + self, + session: AsyncSession, + keys: list[str] | list[str | list[int] | int], + ids: list[int], + ): query = ( sa.select(SQLJSONModel) - .where(json_contains(keys, SQLJSONModel.data)) + .where(sa.bindparam("keys", keys, type_=JSON).contains(SQLJSONModel.data)) .order_by(SQLJSONModel.id) ) assert await self.get_ids(session, query) == ids @@ -293,10 +303,18 @@ async def test_json_contains_left_side_literal(self, session, keys, ids): (["a"], ["a", "b"], False), ], ) - async def test_json_contains_both_sides_literal(self, session, left, right, match): - query = sa.select(sa.literal("match")).where(json_contains(left, right)) - result = await session.execute(query) - assert (result.scalar() == "match") == match + async def test_json_contains_both_sides_literal( + self, + session: AsyncSession, + left: list[str], + right: list[str], + match: bool, + ): + query = sa.select(sa.literal("match")).where( + sa.bindparam("left", left, type_=JSON).contains(right) + ) + result = await session.scalar(query) + assert (result == "match") is match @pytest.mark.parametrize( "id_for_keys,ids_for_results", @@ -306,17 +324,19 @@ async def test_json_contains_both_sides_literal(self, session, left, right, matc ], ) async def test_json_contains_both_sides_columns( - self, session, id_for_keys, ids_for_results + self, + session: AsyncSession, + id_for_keys: list[list[int] | int], + ids_for_results: list[list[int] | int], ): query = ( sa.select(SQLJSONModel) .where( - json_contains( - SQLJSONModel.data, + SQLJSONModel.data.contains( # select the data corresponding to the `id_for_keys` id sa.select(SQLJSONModel.data) .where(SQLJSONModel.id == id_for_keys) - .scalar_subquery(), + .scalar_subquery() ) ) .order_by(SQLJSONModel.id) @@ -335,15 +355,20 @@ async def 
test_json_contains_both_sides_columns( ([], []), ], ) - async def test_json_has_any_key(self, session, keys, ids): + async def test_json_has_any_key( + self, + session: AsyncSession, + keys: list[str], + ids: list[int], + ): query = ( sa.select(SQLJSONModel) - .where(json_has_any_key(SQLJSONModel.data, keys)) + .where(SQLJSONModel.data.has_any(sa.cast(array(keys), ARRAY(sa.String)))) .order_by(SQLJSONModel.id) ) assert await self.get_ids(session, query) == ids - async def test_multiple_json_has_any(self, session): + async def test_multiple_json_has_any(self, session: AsyncSession): """ SQLAlchemy's default bindparam has a `.` in it, which SQLite rejects. We create a custom bindparam name with `unique=True` to avoid confusion; @@ -355,11 +380,11 @@ async def test_multiple_json_has_any(self, session): .where( sa.or_( sa.and_( - json_has_any_key(SQLJSONModel.data, ["a"]), - json_has_any_key(SQLJSONModel.data, ["b"]), + SQLJSONModel.data.has_any(array(["a"])), + SQLJSONModel.data.has_any(array(["b"])), ), - json_has_any_key(SQLJSONModel.data, ["c"]), - json_has_any_key(SQLJSONModel.data, ["d"]), + SQLJSONModel.data.has_any(array(["c"])), + SQLJSONModel.data.has_any(array(["d"])), ), ) .order_by(SQLJSONModel.id) @@ -376,80 +401,59 @@ async def test_multiple_json_has_any(self, session): ([], [1, 2, 3, 4, 5, 6]), ], ) - async def test_json_has_all_keys(self, session, keys, ids): + async def test_json_has_all_keys( + self, + session: AsyncSession, + keys: list[str], + ids: list[int], + ): query = ( sa.select(SQLJSONModel) - .where(json_has_all_keys(SQLJSONModel.data, keys)) + .where(SQLJSONModel.data.has_all(sa.cast(array(keys), ARRAY(sa.String())))) .order_by(SQLJSONModel.id) ) assert await self.get_ids(session, query) == ids - async def test_json_has_all_keys_requires_scalar_inputs(self): - with pytest.raises(ValueError, match="(values must be strings)"): - json_has_all_keys(SQLJSONModel.data, ["a", 3]) - - async def test_json_has_any_key_requires_scalar_inputs(self): - 
with pytest.raises(ValueError, match="(values must be strings)"): - json_has_any_key(SQLJSONModel.data, ["a", 3]) - async def test_json_functions_use_postgres_operators_with_postgres(self): dialect = sa.dialects.postgresql.dialect() extract_statement = SQLJSONModel.data["x"].compile(dialect=dialect) - alt_extract_statement = json_extract(SQLJSONModel.data, "x").compile( - dialect=dialect - ) - contains_stmt = json_contains(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - any_stmt = json_has_any_key(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - all_stmt = json_has_all_keys(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - - assert "->" in str(extract_statement) - assert "JSON_EXTRACT" not in str(extract_statement) - assert "->" in str(alt_extract_statement) - assert "JSON_EXTRACT" not in str(alt_extract_statement) - assert "@>" in str(contains_stmt) - assert "json_each" not in str(contains_stmt) - assert "?|" in str(any_stmt) - assert "json_each" not in str(any_stmt) - assert "?&" in str(all_stmt) - assert "json_each" not in str(all_stmt) + alt_extract_statement = SQLJSONModel.data["x"].astext.compile(dialect=dialect) + contains_stmt = SQLJSONModel.data.contains(["x"]).compile(dialect=dialect) + any_stmt = SQLJSONModel.data.has_any(array(["x"])).compile(dialect=dialect) + all_stmt = SQLJSONModel.data.has_all(array(["x"])).compile(dialect=dialect) + + assert ".data -> " in str(extract_statement) + assert ".data ->> " in str(alt_extract_statement) + assert ".data @> " in str(contains_stmt) + assert ".data ?| " in str(any_stmt) + assert ".data ?& " in str(all_stmt) async def test_json_functions_dont_use_postgres_operators_with_sqlite(self): dialect = sa.dialects.sqlite.dialect() extract_statement = SQLJSONModel.data["x"].compile(dialect=dialect) - alt_extract_statement = json_extract(SQLJSONModel.data, "x").compile( - dialect=dialect - ) - contains_stmt = json_contains(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - any_stmt = 
json_has_any_key(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - all_stmt = json_has_all_keys(SQLJSONModel.data, ["x"]).compile(dialect=dialect) - - assert "->" not in str(extract_statement) - assert "JSON_EXTRACT" in str(extract_statement) - assert "->" not in str(alt_extract_statement) - assert "JSON_EXTRACT" in str(alt_extract_statement) - assert "@>" not in str(contains_stmt) - assert "json_each" in str(contains_stmt) - assert "?|" not in str(any_stmt) - assert "json_each" in str(any_stmt) - assert "?&" not in str(all_stmt) - assert "json_each" in str(all_stmt) - - async def test_sqlite_json_extract_wrap_quotes(self): + alt_extract_statement = SQLJSONModel.data["x"].astext.compile(dialect=dialect) + contains_stmt = SQLJSONModel.data.contains(["x"]).compile(dialect=dialect) + any_stmt = SQLJSONModel.data.has_any(array(["x"])).compile(dialect=dialect) + all_stmt = SQLJSONModel.data.has_all(array(["x"])).compile(dialect=dialect) + + assert ".data ->" not in str(extract_statement) + assert ".data ->>" not in str(alt_extract_statement) + assert ".data @>" not in str(contains_stmt) + assert ".data ?|" not in str(any_stmt) + assert ".data ?&" not in str(all_stmt) + + async def test_sqlite_json_extract_as_json_extract(self): dialect = sa.dialects.sqlite.dialect() - extract_statement = json_extract( - SQLJSONModel.data, "x.y.z", wrap_quotes=True - ).compile(dialect=dialect) - assert '$."x.y.z"' in str(extract_statement) + extract_statement = SQLJSONModel.data["x.y.z"].astext.compile(dialect=dialect) + assert "->>" not in str(extract_statement) + assert "JSON_EXTRACT" in str(extract_statement) - async def test_postgres_json_extract_no_wrap_quotes(self): + async def test_postgres_json_extract_as_native_operator(self): dialect = sa.dialects.postgresql.dialect() - extract_statement = json_extract( - SQLJSONModel.data, "x.y.z", wrap_quotes=True - ).compile(dialect=dialect) - assert "x.y.z" in str(extract_statement) - assert '"x.y.z"' not in str(extract_statement) + 
extract_statement = SQLJSONModel.data["x.y.z"].astext.compile(dialect=dialect) + assert "->>" in str(extract_statement) @pytest.mark.parametrize("extrema", [-math.inf, math.nan, +math.inf]) async def test_json_floating_point_extrema( @@ -460,18 +464,40 @@ async def test_json_floating_point_extrema( await session.flush() session.expire(example) - result = await session.execute( + result = await session.scalar( sa.select(SQLJSONModel).where(SQLJSONModel.id == 100) ) - from_db: SQLJSONModel = result.scalars().first() - assert from_db.data == [-1.0, None, 1.0] + assert result is not None + assert result.data == [-1.0, None, 1.0] + + +class TestCustomFunctions: + def test_sqlite_now_compilation(self) -> None: + dialect = sa.dialects.sqlite.dialect() + expression = sa.func.now() + compiled = str(expression.compile(dialect=dialect)) + assert compiled == "strftime('%Y-%m-%d %H:%M:%f000', 'now')" + + @pytest.mark.parametrize( + "dialect,expected_function", + ( + (sa.dialects.sqlite.dialect(), "max"), + (sa.dialects.postgresql.dialect(), "greatest"), + ), + ) + def test_greatest_compilation( + self, dialect: sa.Dialect, expected_function: str + ) -> None: + expression = sa.func.greatest(17, 42, 11) + compiled = str(expression.compile(dialect=dialect)) + assert compiled.partition("(")[0] == expected_function class TestDateFunctions: """Test combinations of Python literals and DB columns""" @pytest.fixture(autouse=True) - async def create_data(self, session): + async def create_data(self, session: AsyncSession): model = SQLTimestampModel( ts_1=pendulum.datetime(2021, 1, 1), ts_2=pendulum.datetime(2021, 1, 4, 0, 5), @@ -490,11 +516,16 @@ async def create_data(self, session): (SQLTimestampModel.ts_1, SQLTimestampModel.i_1), ], ) - async def test_date_add(self, session, ts_1, i_1): - result = await session.execute( - sa.select(date_add(ts_1, i_1)).select_from(SQLTimestampModel) + async def test_date_add( + self, + session: AsyncSession, + ts_1: Union[pendulum.DateTime, 
sa.Column[pendulum.DateTime]], + i_1: Union[datetime.timedelta, sa.Column[datetime.timedelta]], + ): + result = await session.scalar( + sa.select(sa.func.date_add(ts_1, i_1)).select_from(SQLTimestampModel) ) - assert result.scalar() == pendulum.datetime(2021, 1, 4, 0, 5) + assert result == pendulum.datetime(2021, 1, 4, 0, 5) @pytest.mark.parametrize( "ts_1, ts_2", @@ -505,11 +536,16 @@ async def test_date_add(self, session, ts_1, i_1): (SQLTimestampModel.ts_1, SQLTimestampModel.ts_2), ], ) - async def test_date_diff(self, session, ts_1, ts_2): - result = await session.execute( - sa.select(date_diff(ts_2, ts_1)).select_from(SQLTimestampModel) + async def test_date_diff( + self, + session: AsyncSession, + ts_1: Union[pendulum.DateTime, sa.Column[pendulum.DateTime]], + ts_2: Union[pendulum.DateTime, sa.Column[pendulum.DateTime]], + ): + result = await session.scalar( + sa.select(sa.func.date_diff(ts_2, ts_1)).select_from(SQLTimestampModel) ) - assert result.scalar() == datetime.timedelta(days=3, minutes=5) + assert result == datetime.timedelta(days=3, minutes=5) @pytest.mark.parametrize( "i_1, i_2", @@ -523,11 +559,60 @@ async def test_date_diff(self, session, ts_1, ts_2): (SQLTimestampModel.i_1, SQLTimestampModel.i_2), ], ) - async def test_interval_add(self, session, i_1, i_2): - result = await session.execute( - sa.select(interval_add(i_1, i_2)).select_from(SQLTimestampModel) + async def test_interval_add( + self, + session: AsyncSession, + i_1: Union[datetime.timedelta, sa.Column[datetime.timedelta]], + i_2: Union[datetime.timedelta, sa.Column[datetime.timedelta]], + ): + result = await session.scalar( + sa.select(sa.func.interval_add(i_1, i_2)).select_from(SQLTimestampModel) + ) + assert result == datetime.timedelta(days=3, minutes=48) + + @pytest.mark.parametrize( + "ts_1, ts_2", + [ + (pendulum.datetime(2021, 1, 1), pendulum.datetime(2021, 1, 4, 0, 5)), + (pendulum.datetime(2021, 1, 1), SQLTimestampModel.ts_2), + (SQLTimestampModel.ts_1, pendulum.datetime(2021, 
1, 4, 0, 5)), + (SQLTimestampModel.ts_1, SQLTimestampModel.ts_2), + ], + ) + async def test_date_diff_seconds( + self, + session: AsyncSession, + ts_1: Union[pendulum.DateTime, sa.Column[pendulum.DateTime]], + ts_2: Union[pendulum.DateTime, sa.Column[pendulum.DateTime]], + ): + result = await session.scalar( + sa.select(sa.func.date_diff_seconds(ts_2, ts_1)).select_from( + SQLTimestampModel + ) + ) + assert pytest.approx(result) == 259500.0 + + async def test_date_diff_seconds_from_now_literal(self, session: AsyncSession): + value = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=17) + result: Optional[float] = await session.scalar( + sa.select(sa.func.date_diff_seconds(value)) + ) + assert result is not None + assert 16 <= result <= 18 + + async def test_date_diff_seconds_from_now_column(self, session: AsyncSession): + value = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=17) + model = SQLTimestampModel(ts_1=value) + session.add(model) + await session.commit() + + result: Optional[float] = await session.scalar( + sa.select(sa.func.date_diff_seconds(SQLTimestampModel.ts_1)).where( + SQLTimestampModel.id == model.id + ) ) - assert result.scalar() == datetime.timedelta(days=3, minutes=48) + assert result is not None + assert 16 <= result <= 18 async def test_error_thrown_if_sqlite_version_is_below_minimum():