Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] prefect.server.utilities.database #16362

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/prefect/server/api/run_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ async def run_history(
# estimated run times only includes positive run times (to avoid any unexpected corner cases)
"sum_estimated_run_time",
sa.func.sum(
db.greatest(0, sa.extract("epoch", runs.c.estimated_run_time))
sa.func.greatest(
0, sa.extract("epoch", runs.c.estimated_run_time)
)
),
# estimated lateness is the sum of any positive start time deltas
"sum_estimated_lateness",
sa.func.sum(
db.greatest(
sa.func.greatest(
0, sa.extract("epoch", runs.c.estimated_start_time_delta)
)
),
Expand Down
43 changes: 5 additions & 38 deletions src/prefect/server/api/ui/task_runs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
from datetime import datetime, timezone
from datetime import datetime
from typing import List, Optional, cast

import pendulum
Expand Down Expand Up @@ -37,37 +36,6 @@ def ser_model(self) -> dict:
}


def _postgres_bucket_expression(
db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
):
# asyncpg under Python 3.7 doesn't support timezone-aware datetimes for the EXTRACT
# function, so we will send it as a naive datetime in UTC
if sys.version_info < (3, 8):
start_datetime = start_datetime.astimezone(timezone.utc).replace(tzinfo=None)

return sa.func.floor(
(
sa.func.extract("epoch", db.TaskRun.start_time)
- sa.func.extract("epoch", start_datetime)
)
/ delta.total_seconds()
).label("bucket")


def _sqlite_bucket_expression(
db: PrefectDBInterface, delta: pendulum.Duration, start_datetime: datetime
):
return sa.func.floor(
(
(
sa.func.strftime("%s", db.TaskRun.start_time)
- sa.func.strftime("%s", start_datetime)
)
/ delta.total_seconds()
)
).label("bucket")


@router.post("/dashboard/counts")
async def read_dashboard_task_run_counts(
task_runs: schemas.filters.TaskRunFilter,
Expand Down Expand Up @@ -121,11 +89,10 @@ async def read_dashboard_task_run_counts(
start_time.microsecond,
start_time.timezone,
)
bucket_expression = (
_sqlite_bucket_expression(db, delta, start_datetime)
if db.dialect.name == "sqlite"
else _postgres_bucket_expression(db, delta, start_datetime)
)
bucket_expression = sa.func.floor(
sa.func.date_diff_seconds(db.TaskRun.start_time, start_datetime)
/ delta.total_seconds()
).label("bucket")

raw_counts = (
(
Expand Down
18 changes: 8 additions & 10 deletions src/prefect/server/database/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
from typing import Dict, Hashable, Optional, Tuple

import sqlalchemy as sa

try:
from sqlalchemy import AdaptedConnection
from sqlalchemy.pool import ConnectionPoolEntry
except ImportError:
# SQLAlchemy 1.4 equivalents
from sqlalchemy.pool import _ConnectionFairy as AdaptedConnection
from sqlalchemy.pool.base import _ConnectionRecord as ConnectionPoolEntry

from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy import AdaptedConnection
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
AsyncSessionTransaction,
create_async_engine,
)
from sqlalchemy.pool import ConnectionPoolEntry
from typing_extensions import Literal

from prefect.settings import (
Expand Down
3 changes: 0 additions & 3 deletions src/prefect/server/database/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,6 @@ def insert(self, model):
"""INSERTs a model into the database"""
return self.queries.insert(model)

def greatest(self, *values):
return self.queries.greatest(*values)

def make_timestamp_intervals(
self,
start_time: datetime.datetime,
Expand Down
41 changes: 18 additions & 23 deletions src/prefect/server/database/orm_models.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import datetime
import uuid
from abc import ABC, abstractmethod
from collections.abc import Hashable, Iterable
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Hashable,
Iterable,
Optional,
Union,
cast,
)

import pendulum
Expand Down Expand Up @@ -46,15 +44,12 @@
WorkQueueStatus,
)
from prefect.server.utilities.database import (
CAMEL_TO_SNAKE,
JSON,
UUID,
GenerateUUID,
Pydantic,
Timestamp,
camel_to_snake,
date_diff,
interval_add,
now,
)
from prefect.server.utilities.encryption import decrypt_fernet, encrypt_fernet
from prefect.utilities.names import generate_slug
Expand Down Expand Up @@ -117,7 +112,7 @@ def __tablename__(cls) -> str:
into a snake-case table name. Override by providing
an explicit `__tablename__` class property.
"""
return camel_to_snake.sub("_", cls.__name__).lower()
return CAMEL_TO_SNAKE.sub("_", cls.__name__).lower()

id: Mapped[uuid.UUID] = mapped_column(
primary_key=True,
Expand All @@ -126,17 +121,17 @@ def __tablename__(cls) -> str:
)

created: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)

# onupdate is only called when statements are actually issued
# against the database. until COMMIT is issued, this column
# will not be updated
updated: Mapped[pendulum.DateTime] = mapped_column(
index=True,
server_default=now(),
server_default=sa.func.now(),
default=lambda: pendulum.now("UTC"),
onupdate=now(),
onupdate=sa.func.now(),
server_onupdate=FetchedValue(),
)

Expand Down Expand Up @@ -175,7 +170,7 @@ class FlowRunState(Base):
sa.Enum(schemas.states.StateType, name="state_type"), index=True
)
timestamp: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)
name: Mapped[str] = mapped_column(index=True)
message: Mapped[Optional[str]]
Expand Down Expand Up @@ -240,7 +235,7 @@ class TaskRunState(Base):
sa.Enum(schemas.states.StateType, name="state_type"), index=True
)
timestamp: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)
name: Mapped[str] = mapped_column(index=True)
message: Mapped[Optional[str]]
Expand Down Expand Up @@ -419,9 +414,9 @@ def _estimated_run_time_expression(cls) -> sa.Label[datetime.timedelta]:
sa.case(
(
cls.state_type == schemas.states.StateType.RUNNING,
interval_add(
sa.func.interval_add(
cls.total_run_time,
date_diff(now(), cls.state_timestamp),
sa.func.date_diff(sa.func.now(), cls.state_timestamp),
),
),
else_=cls.total_run_time,
Expand Down Expand Up @@ -464,15 +459,15 @@ def _estimated_start_time_delta_expression(
return sa.case(
(
cls.start_time > cls.expected_start_time,
date_diff(cls.start_time, cls.expected_start_time),
sa.func.date_diff(cls.start_time, cls.expected_start_time),
),
(
sa.and_(
cls.start_time.is_(None),
cls.state_type.not_in(schemas.states.TERMINAL_STATES),
cls.expected_start_time < now(),
cls.expected_start_time < sa.func.now(),
),
date_diff(now(), cls.expected_start_time),
sa.func.date_diff(sa.func.now(), cls.expected_start_time),
),
else_=datetime.timedelta(0),
)
Expand Down Expand Up @@ -1165,7 +1160,7 @@ class Worker(Base):

name: Mapped[str]
last_heartbeat_time: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)
heartbeat_interval_seconds: Mapped[Optional[int]]

Expand Down Expand Up @@ -1195,7 +1190,7 @@ class Agent(Base):
)

last_activity_time: Mapped[pendulum.DateTime] = mapped_column(
server_default=now(), default=lambda: pendulum.now("UTC")
server_default=sa.func.now(), default=lambda: pendulum.now("UTC")
)

__table_args__: Any = (sa.UniqueConstraint("name"),)
Expand Down Expand Up @@ -1277,11 +1272,11 @@ class Automation(Base):
@classmethod
def sort_expression(cls, value: AutomationSort) -> sa.ColumnExpressionArgument[Any]:
"""Return an expression used to sort Automations"""
sort_mapping = {
sort_mapping: dict[AutomationSort, sa.ColumnExpressionArgument[Any]] = {
AutomationSort.CREATED_DESC: cls.created.desc(),
AutomationSort.UPDATED_DESC: cls.updated.desc(),
AutomationSort.NAME_ASC: cast(sa.Column, cls.name).asc(),
AutomationSort.NAME_DESC: cast(sa.Column, cls.name).desc(),
AutomationSort.NAME_ASC: cls.name.asc(),
AutomationSort.NAME_DESC: cls.name.desc(),
}
return sort_mapping[value]

Expand Down
22 changes: 1 addition & 21 deletions src/prefect/server/database/query_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,6 @@ def _unique_key(self) -> Tuple[Hashable, ...]:
def insert(self, obj) -> Union[postgresql.Insert, sqlite.Insert]:
"""dialect-specific insert statement"""

@abstractmethod
def greatest(self, *values):
"""dialect-specific SqlAlchemy binding"""

@abstractmethod
def least(self, *values):
"""dialect-specific SqlAlchemy binding"""

# --- dialect-specific JSON handling

@abstractproperty
Expand Down Expand Up @@ -179,7 +171,7 @@ def get_scheduled_flow_runs_from_work_queues(
concurrency_queues = (
sa.select(
orm_models.WorkQueue.id,
self.greatest(
sa.func.greatest(
0,
orm_models.WorkQueue.concurrency_limit
- sa.func.count(orm_models.FlowRun.id),
Expand Down Expand Up @@ -628,12 +620,6 @@ class AsyncPostgresQueryComponents(BaseQueryComponents):
def insert(self, obj) -> postgresql.Insert:
return postgresql.insert(obj)

def greatest(self, *values):
return sa.func.greatest(*values)

def least(self, *values):
return sa.func.least(*values)

# --- Postgres-specific JSON handling

@property
Expand Down Expand Up @@ -984,12 +970,6 @@ class AioSqliteQueryComponents(BaseQueryComponents):
def insert(self, obj) -> sqlite.Insert:
return sqlite.insert(obj)

def greatest(self, *values):
return sa.func.max(*values)

def least(self, *values):
return sa.func.min(*values)

# --- Sqlite-specific JSON handling

@property
Expand Down
13 changes: 2 additions & 11 deletions src/prefect/server/events/counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from prefect.server.database.dependencies import provide_database_interface
from prefect.server.database.interface import PrefectDBInterface
from prefect.server.utilities.database import json_extract
from prefect.types import DateTime
from prefect.utilities.collections import AutoEnum

Expand Down Expand Up @@ -290,16 +289,8 @@ def _database_label_expression(
return db.Event.event
elif self == self.resource:
return sa.func.coalesce(
json_extract(
db.Event.resource,
"prefect.resource.name",
wrap_quotes=db.dialect.name == "sqlite",
),
json_extract(
db.Event.resource,
"prefect.name",
wrap_quotes=db.dialect.name == "sqlite",
),
db.Event.resource["prefect.resource.name"].astext,
db.Event.resource["prefect.name"].astext,
db.Event.resource_id,
)
else:
Expand Down
13 changes: 3 additions & 10 deletions src/prefect/server/events/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
PrefectFilterBaseModel,
PrefectOperatorFilterBaseModel,
)
from prefect.server.utilities.database import json_extract
from prefect.types import DateTime
from prefect.utilities.collections import AutoEnum

Expand Down Expand Up @@ -309,9 +308,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
for _, (label, values) in enumerate(labels.items()):
label_ops = LabelOperations(values)

label_column = json_extract(
orm_models.EventResource.resource, label
)
label_column = orm_models.EventResource.resource[label].astext

# With negative labels, the resource _must_ have the label
if label_ops.negative.simple or label_ops.negative.prefixes:
Expand Down Expand Up @@ -404,9 +401,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
for _, (label, values) in enumerate(labels.items()):
label_ops = LabelOperations(values)

label_column = json_extract(
orm_models.EventResource.resource, label
)
label_column = orm_models.EventResource.resource[label].astext

if label_ops.negative.simple or label_ops.negative.prefixes:
label_filters.append(label_column.is_not(None))
Expand Down Expand Up @@ -518,9 +513,7 @@ def build_where_clauses(self) -> Sequence["ColumnExpressionArgument[bool]"]:
for _, (label, values) in enumerate(labels.items()):
label_ops = LabelOperations(values)

label_column = json_extract(
orm_models.EventResource.resource, label
)
label_column = orm_models.EventResource.resource[label].astext

if label_ops.negative.simple or label_ops.negative.prefixes:
label_filters.append(label_column.is_not(None))
Expand Down
Loading
Loading